0
点赞
收藏
分享

微信扫一扫

Lnton羚通视频分析算法平台【PyTorch】教程:构建模型基础知识

BUILD THE NEURAL NETWORK (构建神经网络)

神经网络由 layers/modules 组成,torch.nn 提供了所有的你需要构建自己的神经网络的 blocks , 每个 module 都在 PyTorch 子类 nn.Module 找到。神经网络本身就是一个 module , 由其他的 modules (layers) 组成,这种嵌套的结构允许轻松的构建和管理复杂的框架结构。

在下面的部分中,我们将构建一个神经网络来对 FashionMNIST 数据集中的图像进行分类。

Get Device for Training

如果 GPU 可用的话,我们希望可以用 GPU 训练自己的模型。用 torch.cuda 判断是否可用,否则就使用 CPU 。

import os 
import torch 
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")

Define the Class (定义神经网络类)

我们定义一个神经网络继承 nn.Module, 用 __init__ 初始化神经网络层,每一个 nn.Module 子类都实现了对数据进行前向推理的操作。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLu(),
            nn.Linear(512, 512),
            nn.ReLu(),
            nn.Linear(512, 10)
        )
  
    def forward(self, x):
        out = self.flatten(x)
        out = self.linear_relu_stack(out)
        return out

创建一个 NeuralNetwork 的实例,并且将它移动到 GPU 上,打印它的网络结构。

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

为了使用该模型,我们将输入数据传递给它,这将执行模型的 forward, 以及一些后台操作,不要直接调用 model.forward() 方法。

在输入上调用模型返回一个 2D tensor, dim=0 对应每个类的 10 个原始预测值的每个输出, dim=1 对应于每个输出的单个值,通过 nn.Softmax module 获得预测概率。

X = torch.rand(1, 28, 28, device=device)
logits = model(X)
pred_prob = nn.Softmax(dim=1)(logits)
y_pred = pred_prob.argmax(dim=1, keepdim=True)
print(f"Pred class is: 	{y_pred}")

Lnton 羚通专注于音视频算法、算力、云平台的高科技人工智能, 基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持 ONVIF、RTSP、GB/T28181 等多协议、多路数的音视频智能分析服务器。

Lnton羚通视频分析算法平台【PyTorch】教程:构建模型基础知识_神经网络

举报

相关推荐

0 条评论