0
点赞
收藏
分享

微信扫一扫

DRM全解析 —— CRTC详解(1)

pipu 2023-10-10 阅读 10

在这里插入图片描述

pytorch实现

x1 = torch.tensor(3.0, requires_grad=True)
y1 = torch.tensor(2.0, requires_grad=True)
a = x1 ** 2
b = 3 * a
c = b * y1
c.backward()
print(x1.grad)
print(y1.grad)
print(x1.grad == 6 * x1 * y1)
print(y1.grad == 3 * (x1 ** 2))

输出为:
tensor(36.)
tensor(27.)
tensor(True)
tensor(True)

默认情况下,pytorch会累加梯度,每次backward()前,需要进行梯度清零

x.grad.zero_()

在这里插入图片描述

举报

相关推荐

0 条评论