PyTorch gradient-related knowledge points

yundejia 2022-03-14

Table of contents

1. requires_grad

To have gradients computed for a tensor, we need to set requires_grad=True on it; tensors are created with requires_grad=False by default.

  • If requires_grad=True is not set, computing gradients later will raise an error
    (1) requires_grad=False (default setting)
import torch

# create an input x with the default settings
x = torch.ones(5)
# y = 2 * sum(x_i^2)
y = 2 * torch.dot(x, x)
# backpropagate from y
y.backward()
# print the gradient of x, i.e. x.grad
print(f"x.grad={x.grad}")
  • Result
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

(2) requires_grad=False set explicitly

import torch

x_false = torch.ones(5, requires_grad=False)
y_false = 2 * torch.dot(x_false, x_false)
y_false.backward()
print(f"x_false.grad={x_false.grad}")
  • Result
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

(3) requires_grad=True

import torch

x_true = torch.ones(5, requires_grad=True)
y_true = 2 * torch.dot(x_true, x_true)
y_true.backward()
print(f"x_true.grad={x_true.grad}")
print(f"4*x_true={4*x_true}")
print(f"x_true.grad==4*x_true={x_true.grad==4*x_true}")
  • Result:
x_true.grad=tensor([4., 4., 4., 4., 4.])
4*x_true=tensor([4., 4., 4., 4., 4.], grad_fn=<MulBackward0>)
x_true.grad==4*x_true=tensor([True, True, True, True, True])
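Beyond setting the flag at creation time, requires_grad can also be switched on afterwards with the in-place method requires_grad_(). A minimal sketch supplementing the three cases above:

```python
import torch

# a tensor created with the defaults does not track gradients
x = torch.ones(5)
print(x.requires_grad)   # False

# enable gradient tracking in place, after creation
x.requires_grad_(True)

# same computation as case (3): y = 2 * sum(x_i^2)
y = 2 * torch.dot(x, x)
y.backward()
print(x.grad)            # tensor([4., 4., 4., 4., 4.])
```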

2. grad_fn, grad

grad: after y.backward() has run, the gradient of x is available as x.grad
grad_fn: records how a variable was produced, i.e. which operation created this node of the graph, in preparation for backpropagation
Take z = 2*x^2 + 6 as an example. From this formula:

  • x: a leaf at the very bottom of the graph (a workhorse like me); hence x.grad_fn=None
  • y = 2*x^2: produced by multiplication, so y.grad_fn = MulBackward
  • z = y + 6: produced by addition, so z.grad_fn = AddBackward
x_true = torch.ones(5, requires_grad=True)
y_true = 2 * torch.dot(x_true, x_true)
z_true = y_true + 6
z_true.backward()
print(f"x_true.grad={x_true.grad}")
print(f"x_true.grad_fn={x_true.grad_fn}")
print(f"y_true.grad_fn={y_true.grad_fn}")
print(f"z_true.grad_fn={z_true.grad_fn}")

Result:

x_true.grad=tensor([4., 4., 4., 4., 4.])
x_true.grad_fn=None
y_true.grad_fn=<MulBackward0 object at 0x00000180E0DF3550>
z_true.grad_fn=<AddBackward0 object at 0x00000180E0DF3550>
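One caveat worth noting about grad (an addition to the original, not part of it): x.grad accumulates across backward() calls, so in a training loop it is usually zeroed between iterations. A sketch:

```python
import torch

x = torch.ones(5, requires_grad=True)

# first backward pass: dy/dx = 4x
y1 = 2 * torch.dot(x, x)
y1.backward()
print(x.grad)    # tensor([4., 4., 4., 4., 4.])

# second backward pass ADDS to the existing gradient
y2 = 2 * torch.dot(x, x)
y2.backward()
print(x.grad)    # tensor([8., 8., 8., 8., 8.])

# reset before the next pass
x.grad.zero_()
print(x.grad)    # tensor([0., 0., 0., 0., 0.])
```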

3. with torch.no_grad()

torch.no_grad is a context manager that disables gradient computation; operations performed inside it are not recorded in the graph.
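A minimal sketch of the behavior: inside torch.no_grad(), results do not track history, so requires_grad is False and grad_fn is None even though the inputs require gradients.

```python
import torch

x = torch.ones(5, requires_grad=True)

# outside the context manager, the result tracks history
y = 2 * torch.dot(x, x)
print(y.requires_grad)   # True

# inside it, the same computation is not recorded
with torch.no_grad():
    z = 2 * torch.dot(x, x)
print(z.requires_grad)   # False
print(z.grad_fn)         # None
```

This is the usual pattern for inference, where building the graph would only waste memory.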

4. torch.detach()
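The original post leaves no code here, so this is a sketch from the PyTorch semantics: detach() returns a tensor that shares the same data as the original but is cut off from the computation graph, so gradients do not flow back through it.

```python
import torch

x = torch.ones(5, requires_grad=True)
y = 2 * torch.dot(x, x)

# the detached copy no longer tracks history
y_detached = y.detach()
print(y_detached.requires_grad)    # False
print(y_detached.grad_fn)          # None

# but it still holds the same value as y
print(torch.equal(y, y_detached))  # True
```

A common use is converting an intermediate result to NumPy or logging it without keeping the graph alive.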
