1、线性层linear
1、官方简介
简单解释:下图中的input layer为in_features, hidden layer为out_features, 经过线性层将x转换为g
2、代码
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
#导入数据集
datasets = torchvision.datasets.CIFAR10("E:/PycharmProjects/Pytoch_learning/dataset/CIFAR10", train=False,transform=torchvision.transforms.ToTensor())
#加载数据集
dataloader = DataLoader(datasets,batch_size=64)
#构建模型
class Tian(nn.Module):
def __init__(self):
super(Tian, self).__init__()
#in_features=196608, out_features=10
self.linear = Linear(196608, 10)
def forward(self, input):
output = self.linear(input)
return output
#应用模型
ren = Tian()
for data in dataloader:
imgs,tagets = data
print(imgs.shape)#输出([64,3,32,32])
#改变尺寸
# output = torch.reshape(imgs, (1, 1, 1, -1))
#使用flatten
output = torch.flatten(imgs)
# print("After reshape:", output.shape)
print("After flatten:", output.shape)
#应用模型
output_linear = ren(output)
print("使用linear后:", output_linear.shape)
说明:关于CIFAR数据集,本人已下载好,若未下载请参考新手数据集下载
3、结果分析
直接加载的图片尺寸:torch.Size()
使用torch.reshape后输出尺寸:After reshape:
使用torch.flatten后输出尺寸:After flatten:
经由Linear layer,其中in_features=196608, out_features=10输出尺寸:使用linear后
2、torch.flatten()
其实就是“拉平”的意思,代码在1中