0
点赞
收藏
分享

微信扫一扫

Pytorch:torch.utils.checkpoint()

孟佳 03-21 14:00 阅读 2

在PyTorch中,torch.utils.checkpoint 模块提供了实现梯度检查点(也称为checkpointing)的功能。这个技术主要用于训练时内存优化,它允许我们以计算时间为代价,减少训练深度网络时的内存占用。

原理

梯度检查点技术的基本原理是,在前向传播的过程中,并不保存所有的中间激活值。相反,它只保存一部分关键的激活值在反向传播时,根据保留的激活值重新计算丢弃的中间激活值。因此内存的使用量会下降,但计算量会增加,因为需要重新计算一些前向传播的部分。

用法

torch.utils.checkpoint 中主要的函数是 checkpoint。checkpoint 函数可以用来封装模型的一部分或者一个复杂的运算,这部分会使用梯度检查点。它的一般用法是:

import torch
from torch.utils.checkpoint import checkpoint

# 定义一个前向传播函数
def custom_forward(*inputs):
    # 定义你的前向传播逻辑
    # 例如: x, y = inputs; result = x + y
    ...
    return result

# 在训练的前向传播过程中使用梯度检查点
model_output = checkpoint(custom_forward, *model_inputs)

在每次调用 custom_forward 函数时,它都会返回正常的前向传播结果。不过,checkpoint 函数会确保仅保留必须的激活值(即 custom_forward 的输出)。其他激活值不会保存在内存中,需要在反向传播时重新计算。

下面是一个具体的示例,演示了如何在一个简单的模型中使用 checkpoint 函数:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SomeModel(nn.Module):
    def __init__(self):
        super(SomeModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)

    def forward(self, x):
        # 使用checkpoint来减少第二层卷积的内存使用量
        x = self.conv1(x)
        x = checkpoint(self.conv2, x)
        return x

model = SomeModel()
input = torch.randn(1, 1, 28, 28)
output = model(input)
loss = output.sum()
loss.backward()

在上面的例子中,conv2的前向计算是通过 checkpoint 封装的,这意味着在 conv1 的输出和 conv2 的输出之间的激活值不会被完全存储。在反向传播时,这些丢失的激活值会通过再次前向传递 conv2 来重新计算。
使用梯度检查点技术可以在训练大型模型时减少显存的占用,但由于在反向传播时额外的重新计算,它会增加一些计算成本。

举报

相关推荐

0 条评论