0
点赞
收藏
分享

微信扫一扫

电巢科技出席第26届西北地区电子技术与线路课程教学改革研讨会,聚焦一流课程建设!

数数扁桃 2023-09-21 阅读 15

        

目录

一、前言

二、实验环境

三、PyTorch数据结构

0、分类

1、张量(Tensor)

2、张量操作(Tensor Operations)

3、变量(Variable)

4、数据集(Dataset)

5、数据加载器(DataLoader)

6、模块(Module)


一、前言

ChatGPT:

二、实验环境

        本系列实验使用如下环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib

关于配置环境问题,可参考前文的惨痛经历:

三、PyTorch数据结构

0、分类

  • Tensor(张量):Tensor是PyTorch中最基本的数据结构,类似于多维数组。它可以表示标量、向量、矩阵或任意维度的数组。
  • Tensor的操作:PyTorch提供了丰富的操作函数,用于对Tensor进行各种操作,如数学运算、统计计算、张量变形、索引和切片等。这些操作函数能够高效地利用GPU进行并行计算,加速模型训练过程。
  • Variable(变量):Variable是对Tensor的封装,用于自动求导。在PyTorch中,Variable会自动跟踪和记录对其进行的操作,从而构建计算图并支持自动求导。在PyTorch 0.4.0及以后的版本中,Variable被废弃,可以直接使用Tensor来进行自动求导。
  • Dataset(数据集):Dataset是一个抽象类,用于表示数据集。通过继承Dataset类,可以自定义数据集,并实现数据加载、预处理和获取样本等功能。PyTorch还提供了一些内置的数据集类,如MNIST、CIFAR-10等,用于方便地加载常用的数据集。
  • DataLoader(数据加载器):DataLoader用于将Dataset中的数据按批次加载,并提供多线程和多进程的数据预读功能。它可以高效地加载大规模的数据集,并支持数据的随机打乱、并行加载和数据增强等操作。
  • Module(模块):Module是PyTorch中用于构建模型的基类。通过继承Module类,可以定义自己的模型,并实现前向传播和反向传播等方法。Module提供了参数管理、模型保存和加载等功能,方便模型的训练和部署。

1、张量(Tensor

        

PyTorch数据结构:1、Tensor(张量):维度(Dimensions)、数据类型(Data Types)_QomolangmaH的博客-CSDN博客https://blog.csdn.net/m0_63834988/article/details/132909219​编辑https://blog.csdn.net/m0_63834988/article/details/132909219icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132909219

2、张量操作(Tensor Operations)

3、变量(Variable)

PyTorch数据结构:3、变量(Variable)介绍_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132923358?spm=1001.2014.3001.5501

4、数据集(Dataset)

PyTorch数据结构:4、数据集(Dataset)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132923697?spm=1001.2014.3001.5501

5、数据加载器(DataLoader)

PyTorch数据结构:5、数据加载器(DataLoader)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132924381?spm=1001.2014.3001.5502

6、模块(Module)

        在PyTorch中,构建模型的一般步骤是创建一个自定义的模型类,并将其作为Module的子类。在模型类的初始化方法中(通常为__init__),可以定义模型的层结构和参数。通过重写前向传播方法(forward),可以定义数据在模型中的流动方式。

        Module类提供了许多有用的功能:

  •  参数管理:Module类自动跟踪和管理模型中的可学习参数。这些参数可以通过parameters()方法获得,方便进行优化和参数更新操作。

  • 模型保存和加载:可以使用torch.save()方法将整个模型保存到文件中,以便在以后重新加载和使用。加载模型时,可以使用torch.load()方法加载保存的模型参数。

  • 自动求导:Module类内部的操作默认支持自动求导。在前向传播过程中,PyTorch会自动构建计算图,并记录每个操作的梯度计算方式。这样,在反向传播过程中,可以自动计算和更新模型的参数梯度。 

        除了上述功能,Module类还提供了一些其他方法和属性,例如train()eval()方法用于切换模型的训练和评估模式,to()方法用于将模型移动到指定的设备(如CPU或GPU)等。

import torch
import torch.nn as nn

# 自定义模型类
class CustomModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(CustomModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# 创建一个自定义模型的实例
input_size = 10
hidden_size = 20
num_classes = 2

model = CustomModel(input_size, hidden_size, num_classes)

# 使用模型进行前向传播
input_data = torch.randn(5, input_size)
output = model(input_data)

print(output)

# 保存模型
torch.save(model.state_dict(), 'model.pth')

# 加载模型
loaded_model = CustomModel(input_size, hidden_size, num_classes)
loaded_model.load_state_dict(torch.load('model.pth'))
  • 定义了一个自定义的模型类CustomModel,继承自nn.Module
  • 在模型的初始化方法中定义了模型的层结构,并在前向传播方法中定义了数据的流动方式。
  • 通过创建模型实例model,我们可以使用它进行前向传播。
  • 保存模型参数到文件中,并在之后加载参数来恢复模型。

  

举报
0 条评论