[Disclaimer] Source: the Bilibili video course "PyTorch Deep Learning Quick Start Tutorial (Absolutely Easy to Understand!)" by 小土堆 (Xiaotudui).
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader

# CIFAR10 test split (train=False), with each image converted to a 3x32x32 float tensor
dataset = torchvision.datasets.CIFAR10('./data', train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
# Group the images into batches of 64
dataloader = DataLoader(dataset=dataset, batch_size=64)
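With `train=False`, `CIFAR10` loads the 10,000-image test split. So `batch_size=64` yields 156 full batches plus one final batch of 16 images. `DataLoader` keeps that short batch because `drop_last` defaults to `False`. Below is a minimal sketch of this arithmetic, assuming sequential batching. `batch_sizes` is a hypothetical helper, not part of PyTorch:

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Hypothetical helper: sizes of the batches a DataLoader yields, in order."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    # DataLoader keeps the short final batch unless drop_last=True
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


sizes = batch_sizes(10_000, 64)
print(len(sizes), sizes[-1])  # 157 16
```

Passing `drop_last=True` to `DataLoader` would discard the final 16 images instead, leaving 156 batches.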
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # A single convolution layer: 3 input channels (RGB) -> 6 output channels,
        # 3x3 kernel, no padding, stride 1
        self.conv1 = Conv2d(in_channels=3,
                            out_channels=6,
                            kernel_size=3,
                            padding=0,
                            stride=1)

    def forward(self, x):
        x = self.conv1(x)
        return x
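`self.conv1` holds 6 kernels, each of shape 3x3x3 (one 3x3 slice per input channel). Each kernel slides over the image. At every position the layer multiplies elementwise, sums across all input channels, and adds that output channel's bias. PyTorch computes cross-correlation, so the kernel is not flipped.

The pure-Python sketch below reproduces that computation for a single image. It assumes nested lists in `[channels][height][width]` layout. `conv2d_naive` is a hypothetical name, and the sketch ignores dilation and groups, which `nn.Conv2d` also supports:

```python
def conv2d_naive(x, weight, bias, stride=1, padding=0):
    """Hypothetical sketch of the cross-correlation nn.Conv2d computes.

    x:      one input image, nested lists [C_in][H][W]
    weight: kernels, nested lists [C_out][C_in][kH][kW]
    bias:   list of length C_out
    """
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    k_h, k_w = len(weight[0][0]), len(weight[0][0][0])

    # Zero-pad every input channel on all four sides
    padded = [
        [[0.0] * (w + 2 * padding) for _ in range(padding)]
        + [[0.0] * padding + row + [0.0] * padding for row in channel]
        + [[0.0] * (w + 2 * padding) for _ in range(padding)]
        for channel in x
    ]
    h_out = (h + 2 * padding - k_h) // stride + 1
    w_out = (w + 2 * padding - k_w) // stride + 1

    out = []
    for c_out, kernel in enumerate(weight):
        plane = []
        for i in range(h_out):
            row = []
            for j in range(w_out):
                acc = bias[c_out]
                # Sum over every input channel and every kernel position
                for c in range(c_in):
                    for di in range(k_h):
                        for dj in range(k_w):
                            acc += kernel[c][di][dj] * padded[c][i * stride + di][j * stride + dj]
                row.append(acc)
            plane.append(row)
        out.append(plane)
    return out


# Mirror self.conv1: a 3x32x32 image of ones, 6 all-ones 3x3x3 kernels, zero bias
x = [[[1.0] * 32 for _ in range(32)] for _ in range(3)]
weight = [[[[1.0] * 3 for _ in range(3)] for _ in range(3)] for _ in range(6)]
y = conv2d_naive(x, weight, [0.0] * 6)
print(len(y), len(y[0]), len(y[0][0]))  # 6 30 30
print(y[0][0][0])  # 27.0
```

Each output value is 27.0 because every kernel position covers 3 channels x 3 x 3 ones. The shrink from 32 to 30 is the same one seen when the real layer runs on CIFAR10 batches.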
MyModel = Model()
for data in dataloader:
    imgs, labels = data
    output = MyModel(imgs)
    print(imgs.shape)    # full batch: torch.Size([64, 3, 32, 32])
    print(output.shape)  # full batch: torch.Size([64, 6, 30, 30])
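The printed shapes follow from the output-size formula in the `nn.Conv2d` documentation. It is applied to height and width separately: H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1).

For this model, H_in = 32, kernel_size = 3, padding = 0 and stride = 1, with the default dilation = 1. That gives (32 - 2 - 1) / 1 + 1 = 30. So a 64x3x32x32 batch becomes 64x6x30x30, and the final batch has 16 images instead of 64. A small sketch follows; `conv2d_output_size` is a hypothetical helper, not a PyTorch function:

```python
def conv2d_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Hypothetical helper: nn.Conv2d output height or width along one spatial dimension."""
    # Floor division matches the floor() in the formula for these non-negative sizes
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# The settings of self.conv1 applied to a 32x32 CIFAR10 image
print(conv2d_output_size(32, kernel_size=3))             # 30
# padding=1 would keep a 3x3 convolution at 32x32
print(conv2d_output_size(32, kernel_size=3, padding=1))  # 32
# stride=2 roughly halves the spatial size
print(conv2d_output_size(32, kernel_size=3, stride=2))   # 15
```

The channel dimension does not use this formula: it is simply `out_channels` (6 here), whatever the spatial size.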