admin 管理员组文章数量: 1184232
问题提出
在图像、模型压缩算法中往往涉及量化的操作。即将无限、连续的变量映射到有限、离散的空间中,方便存储和进一步计算。比如round操作:
import torch
a = torch.randint(1, 100, (3,)).float()
a.requires_grad = True
print(a) # tensor([19., 36., 46.], requires_grad=True)
# 添加随机噪声
noise = torch.rand(a.shape, requires_grad=True)
b = a + noise
print(b) # tensor([19.2860, 36.9746, 46.6897], grad_fn=<AddBackward0>)
# 对b进行四舍五入得到r_hat
r_hat = torch.round(b)
print(r_hat) # tensor([19., 37., 47.], grad_fn=<RoundBackward0>)
最后一步round()操作在反向传播的过程中是病态的, 因为round()函数的导数除了在x.500为inf外处处为0(x为任意整数)
解决
STE梯度近似
既然r_hat的梯度不可求,那么不妨在反向传播的过程中找一个连续可导的函数,用它的梯度替代round()函数的梯度,而在forward pass中仍然输出四舍五入后的结果。在实际应用中这个替代函数通常选择r=b,以下为了直观,以r=2*b为例。
"""In Pytorch"""
r_diffable = 2*b + (torch.round(b) - 2*b).detach()
r_diffable.retain_grad
print(r_diffable) # =torch.round(b) tensor([19., 37., 47.], grad_fn=<AddBackward0>)
"""In tensorflow"""
r_diffable = 2*b + tf.stop_gradient(tf.round(b) - 2*b)
测试一下梯度是否替换成功
loss = r_diffable.sum()
loss.backward()
# dL/db = dL/dr_hat * dr_hat/db
# dL/dr_hat = 1., dr_hat/db = dr_diffable/db = 2.
# dL/db = 2
print(b.grad) # tensor([2., 2., 2.])
版权声明:本文标题:Pytorch笔记:STE(Straight Through Estimator)解决forward pass中的non-differentiable操作,以torch.round()为例 内容由网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:http://www.roclinux.cn/b/1766361203a3452606.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论