admin 管理员组

文章数量: 1184232

问题提出

在图像、模型压缩算法中往往涉及量化的操作。即将无限、连续的变量映射到有限、离散的空间中,方便存储和进一步计算。比如round操作:

import torch
a = torch.randint(1, 100, (3,)).float()
a.requires_grad = True
print(a)	# tensor([19., 36., 46.], requires_grad=True)

# 添加随机噪声
noise = torch.rand(a.shape, requires_grad=True)
b = a + noise
print(b)	# tensor([19.2860, 36.9746, 46.6897], grad_fn=<AddBackward0>)

# 对b进行四舍五入得到r_hat
r_hat = torch.round(b)
print(r_hat)	# tensor([19., 37., 47.], grad_fn=<RoundBackward0>)

最后一步round()操作在反向传播的过程中是病态的, 因为round()函数的导数除了在x.500为inf外处处为0(x为任意整数)

解决

STE梯度近似

既然r_hat的梯度不可求,那么不妨在反向传播的过程中找一个连续可导的函数,用它的梯度替代round()函数的梯度,而在forward pass中仍然输出四舍五入后的结果。在实际应用中这个替代函数通常选择r=b,以下为了直观,以r=2*b为例。

"""In Pytorch"""
r_diffable = 2*b + (torch.round(b) - 2*b).detach()
r_diffable.retain_grad
print(r_diffable) # =torch.round(b) tensor([19., 37., 47.], grad_fn=<AddBackward0>)

"""In tensorflow"""
r_diffable = 2*b + tf.stop_gradient(tf.round(b) - 2*b)

测试一下梯度是否替换成功

loss = r_diffable.sum()
loss.backward()
# dL/db = dL/dr_hat * dr_hat/db
#  dL/dr_hat = 1., dr_hat/db = dr_diffable/db = 2.
# dL/db = 2
print(b.grad)	# tensor([2., 2., 2.])

本文标签: 为例 操作 笔记 straight STE