0
点赞
收藏
分享

微信扫一扫

torch.utils.checkpoint

参考 torch.utils.checkpoint - 云+社区 - 腾讯云

注意:

在反向传播期间通过对每个检查分割运行一个前向传递分割来实现。这可能导致RNG状态等持久状态比没有检查点时更高级。默认情况下,检查点包含切换RNG状态的逻辑,这样使用RNG(例如通过dropout)的检查点通过与非检查点通过相比具有确定性的输出。根据检查点操作的运行时间,存储和恢复RNG状态的逻辑可能会导致适度的性能下降。如果不需要与非检查点传递相比的确定性输出,则向检查点或checkpoint_sequential提供preserve_rng_state=False,以省略每个检查点期间的RNG状态的存储和恢复。

存储逻辑将当前设备的RNG状态和所有cuda张量参数的设备保存并恢复到run_fn。但是,逻辑无法预测用户是否将张量移动到run_fn本身内的新设备。因此,如果您将张量移动到run_fn内的一个新设备(“new”表示不属于[当前设备+张量参数设备]的集合),那么与非checkpoint传递相比,确定性输出是无法保证的。

torch.utils.checkpoint.checkpoint(function, *args, **kwargs)[source]

检查模型或者模型的一部分。通过将计算变为内存来进行检查点工作。而不是存储用来计算反向传播的整个计算图的中间激活,检查部分不会保存在中间激活中,而是在反向传递中计算它们。它能应用到模型的任何一部分、特别地,在前向传播,函数将以torch.no_grad()方式运行,不存储中间激活。作为替代,前向传递保存输入元组和函数参数。在反向传递中,保存的函数和输入将会被恢复,并且前向传递在函数中再一次计算,现在跟踪中间激活,然后使用这些激活值来计算梯度。

警告:

检查点不和torch.autograd.grad()一起工作,但是仅仅和torch.autograd.backward()一起。

警告:

如果向后的函数调用与向前的函数调用有任何不同,例如,由于一些全局变量,检查点版本将不相等,不幸的是,它不能被检测到。

警告:

如果检查点段包含由detach()或torch.no_grad()从计算图中分离出来的张量,则向后传递将引发错误。这是因为检查点使得所有输出都需要梯度,当一个张量被定义为在模型中没有梯度时,就会产生问题。要绕过这个问题,可以将张量分离到检查点函数之外。

参数:

  • function描述模型或者部分的模型前行传递运行什么。它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过(激活,隐藏),函数应该正确地使用第一个输入作为激活,第二个输入作为隐藏。

  • preserve_rng_state (bool, optional, default=True) – 在每个检查点期间省略RNG状态的存储和恢复。

  • args – 包含函数输入的元组

返回值:

  • **args上运行函数的输出。

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, **kwargs)[source]

用于检查点顺序模型的辅助函数。顺序模型按顺序(顺序)执行一列模块/功能。因此,我们可以将该模型划分为各个分段和每个分段的检查点。除最后一个段外,所有段都将以torch.no_grad()方式运行,而不存储中间激活。每个检查点段的输入将被保存,以便在向后传递中重新运行该段。有关检查点是如何工作的,请参阅checkpoint()。

警告:

检查点不能和torch.autograd.grad()一起使用,但是仅仅和torch.autograd.backward()一起使用。

参数:

  • functions – 一个按顺序运行的torch.nn.Sequential或模块或函数(组成模型)的列表。

  • segments – 要在模型中创建的块的数量

  • input函数的输出张量

  • preserve_rng_state (bool, optional, default=True) – 在每个检查点期间省略RNG状态的存储和恢复。

返回值:

  • 按*input顺序运行函数的输出。

例:

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
举报

相关推荐

0 条评论