参考 torch.utils.checkpoint - 云+社区 - 腾讯云
注意:
在反向传播期间通过对每个检查分割运行一个前向传递分割来实现。这可能导致RNG状态等持久状态比没有检查点时更高级。默认情况下,检查点包含切换RNG状态的逻辑,这样使用RNG(例如通过dropout)的检查点通过与非检查点通过相比具有确定性的输出。根据检查点操作的运行时间,存储和恢复RNG状态的逻辑可能会导致适度的性能下降。如果不需要与非检查点传递相比的确定性输出,则向检查点或checkpoint_sequential提供preserve_rng_state=False,以省略每个检查点期间的RNG状态的存储和恢复。
存储逻辑将当前设备的RNG状态和所有cuda张量参数的设备保存并恢复到run_fn。但是,逻辑无法预测用户是否将张量移动到run_fn本身内的新设备。因此,如果您将张量移动到run_fn内的一个新设备(“new”表示不属于[当前设备+张量参数设备]的集合),那么与非checkpoint传递相比,确定性输出是无法保证的。
torch.utils.checkpoint.
checkpoint
(function, *args, **kwargs)[source]
检查模型或者模型的一部分。通过将计算变为内存来进行检查点工作。而不是存储用来计算反向传播的整个计算图的中间激活,检查部分不会保存在中间激活中,而是在反向传递中计算它们。它能应用到模型的任何一部分、特别地,在前向传播,函数将以torch.no_grad()方式运行,不存储中间激活。作为替代,前向传递保存输入元组和函数参数。在反向传递中,保存的函数和输入将会被恢复,并且前向传递在函数中再一次计算,现在跟踪中间激活,然后使用这些激活值来计算梯度。
警告:
检查点不和torch.autograd.grad()一起工作,但是仅仅和torch.autograd.backward()一起。
警告:
如果向后的函数调用与向前的函数调用有任何不同,例如,由于一些全局变量,检查点版本将不相等,不幸的是,它不能被检测到。
警告:
如果检查点段包含由detach()或torch.no_grad()从计算图中分离出来的张量,则向后传递将引发错误。这是因为检查点使得所有输出都需要梯度,当一个张量被定义为在模型中没有梯度时,就会产生问题。要绕过这个问题,可以将张量分离到检查点函数之外。
参数:
-
function –
描述模型或者部分的模型前行传递运行什么。
它还应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过(激活,隐藏),函数应该正确地使用第一个输入作为激活,第二个输入作为隐藏。 -
preserve_rng_state (bool, optional, default=True) – 在每个检查点期间省略RNG状态的存储和恢复。
-
args – 包含函数输入的元组
返回值:
- **args上运行函数的输出。
torch.utils.checkpoint.
checkpoint_sequential
(functions, segments, input, **kwargs)[source]
用于检查点顺序模型的辅助函数。顺序模型按顺序(顺序)执行一列模块/功能。因此,我们可以将该模型划分为各个分段和每个分段的检查点。除最后一个段外,所有段都将以torch.no_grad()方式运行,而不存储中间激活。每个检查点段的输入将被保存,以便在向后传递中重新运行该段。有关检查点是如何工作的,请参阅checkpoint()。
警告:
检查点不能和torch.autograd.grad()一起使用,但是仅仅和torch.autograd.backward()一起使用。
参数:
-
functions – 一个按顺序运行的torch.nn.Sequential或模块或函数(组成模型)的列表。
-
segments – 要在模型中创建的块的数量
-
input –
函数的输出张量
-
preserve_rng_state (bool, optional, default=True) – 在每个检查点期间省略RNG状态的存储和恢复。
返回值:
- 按*input顺序运行函数的输出。
例:
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)