Pytorch ——基础指北_伍
系列文章目录
Pytorch ——基础指北_零
Pytorch ——基础指北_壹
Pytorch ——基础指北_贰
Pytorch ——基础指北_叁
文章目录
软件环境:
- pytorch 1.10
- pycharm
配套代码下载地址:
gitee-pytorch
基础知识:
本次混讲入过程。
如何通过官方数据集构建自己的数据集?
MNIST数据集
MNIST数据库是非常经典的一个数据集,就像你学编程起初写一个“Hello Word”的程序一样,学Deep Learning你就会写识别MNIST数据集的Model。
MNIST数据集是由0〜9手写数字图片和数字标签所组成的,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。
如何使用MNIST数据集?
在Pytorch中你可以很方便的使用它,只需要下载后,就可以立即使用,下载方法很简单,首先引用针对MNIST数据集的类:from torchvision.datasets import MNIST
。
创建这个类的实例MNIST_dataset = MNIST(root="./data",train=True,download=True,transform=None)
。
其中参数分别是
- root 存储位置;
- train 是否为训练集;
- download是否下载;
- transform处理函数,下一个节点会具体介绍其作用;
下载完成后我们可以输出一个看看一看是什么:
print(MNIST_dataset[0])
可以看到由两部分组成:
图片+一个数字,这个数字就是标记好的数据集的label,指明了本张图片的内容。
值得注意的是MINST_dataset本身就是一个data包下的Dataset类,你从他的继承关系就可以看到:
父类就是继承于Dataset。
Transforms类的数据处理方法
torchvision.transforms.ToTensor
ToTensor是Pytorch内置的数据处理方法,常规的图像数据通常是这样的:
[H,W,C]
- H代表了图像高度,也是H维度的大小;
- W代表了图像宽度,也是W维度的大小;
- C代表的是通道数,也就是每个像素点后的颜色通道大小,通常就是RGB,也就是3,你可以理解为每个像素点都是由这三个颜色在0-255之间混合起来的。
回到我们这个数据集来说,每张图片其实本身是黑白的,也就是说通道数只有一,每张图片的结构应该是 [28,28,1] 。
但是对于Pytorch很多处理函数来说并不支持这样的格式,而是支持 [C,H,W] 的格式,这就麻烦了,我们就需要做维度变换(视角变换)来变换数据格式,为什么这样说的呢?因为本身来说数据是固定的,我们只不过换了一种描述方式来对其进行新的描述,那么我们怎么把数据变换成 [1,28,28] 呢?很简单调用ToTensor这个方法就可以,我们把这个方法传入transforms,transforms会自动调用ToTensor来进行处理。
但是这里还是要讨论一下如果不是支持的格式,我们该怎么使用pytorch进行手动变换呢?
这就要介绍其内置的变换函数,具体使用规则可以参考第一篇所讲解的用法:基础指北_壹,这里直接讲解理论。
图像的维度变换
对于常规图像来说都是遵循这样的格式,[W,H,C] 或者 [H,W,C] 但这样并不是通用的格式,便于理解但是不便于计算,如果我们设想有这样一张图片他的格式是这样的,[2,2,3]也就是一张2x2的图片,每个像素点三个通道分别是RGB,
把三个通道分解开就是这样:
加上维度坐标系就是这样:
我们将数据填补进去,每个颜色在像素点的大小都是在0-255之间,这里就用[1-12]来替代看得更清楚:
你可能会好奇,如果使用**[W,H,C]**怎么表示这个图片呢?
可以展示一下:
>>>t1
>>>tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
以最内的逗号最为z轴的分割,次外层作为y轴的分割,最外层的逗号作为x轴的分割,你可以看到:
可以看出来,最里边的括号就是代表一个像素点的三个通道,对于下标(0,0)的点,他三个通道的值就是1,2,3.从上图也可以看出。
那为什么说轴变换就是视角变换呢?
我们举个例子,在目前的情况下我们将图像从**[W,H,C]变换到[C,H,W]**,图像的样子会发生什么?
我们使用permute函数进行变换,T1的结果如下:
>>>t1= t1.permute(2,1,0)#交换 0 和 2 维度
>>>t1
>>>tensor([[[ 1, 7],
[ 4, 10]],
[[ 2, 8],
[ 5, 11]],
[[ 3, 9],
[ 6, 12]]])
我们同样以最内的逗号最为z轴的分割,次外层作为y轴的分割,最外层的逗号作为x轴的分割,再来把这样图画出来:
你可以发现坐标轴样子没有变,而代表轴的意思不同的了,相当于就是同一块数据使用了不同的视角来进行描述。
如果我们回到二维,你会发现这个操作其实就是矩阵的转置,而目前这个操作就可以叫三维矩阵的转置。
最后补一张思考时候的手稿。
总结一句话:
ToTensor方法就是帮我们实现了 ”维度变换“ 这个操作。
那么如何使用ToTensor方法呢?
使用方法很简单,直接通过其类的call方法传入参数即可
举例:
from torchvision import transforms
MNIST_dataset = MNIST(root="./data", train=True, download=True, transform=None)
MNIST_tensor = transforms.ToTensor()(MNIST_dataset[0][0])
print(MNIST_tensor)
就可以打印出来刚才的图片5。当然实际上你看不出来这一堆数据到底相容的什么。你只要知道他是对
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,
0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,
0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,
0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,
0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,
0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,
0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,
0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,
0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,
0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,
0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,
0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,
0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,
0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,
0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,
0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,
0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,
0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000]]])
Process finished with exit code 0
torchvision.transforms.Normalize(mean ,std)
你可以看到这个方法也是在transforms包下的,说明他也是用来对数据集进行处理的预置方法,这个方法的作用很简单,就是标准化图像里面的数据,使用的时候只需要传入对应的mean(均值),std(标准差)即可,值得注意的是,标准差和方差都指的是图像每个通道的标准差与均值,假如RGB图像,你就应该传入一个(三位的list[a,b,c])。
那么什么是标准化呢?
图像标准化是将数据通过去均值实现中心化的处理,根据凸优化理论与数据概率分布相关知识,数据中心化符合数据分布规律,更容易取得训练之后的泛化效果, 数据标准化是数据预处理的常见方法之一
有公式:
N
o
r
m
a
i
l
i
z
e
d
i
m
a
g
e
=
(
i
m
a
g
e
−
m
e
a
n
)
S
t
d
其
中
:
S
t
d
=
σ
2
=
∑
i
=
1
N
(
x
i
−
μ
)
2
N
(
X
i
是
某
像
素
的
某
一
通
道
值
)
Normailized_{image} = \frac{(image - mean ) }{Std}\\ 其中:\\ Std =\sqrt{\sigma^{2}}=\sqrt{\frac{\sum_{i=1}^{N}\left(x_{i}-\mu\right)^{2}}{N}}(X_i是某像素的某一通道值)
Normailizedimage=Std(image−mean)其中:Std=σ2=N∑i=1N(xi−μ)2(Xi是某像素的某一通道值)
这一大波操作看的头皮发麻,实际上有一个更常用的公式:
你应该瞬间就明白了,就是把图像的数据分布转换成一个均值为0,标准差为1的分布。但是值得注意的是并不是一定为正态分布,取决于原来数据分布的情况。
使用方法同样也是调用call方法也就是在实例化对象后直接()传入图像。
from torchvision.datasets import MNIST
import torchvision
from torchvision import transforms
MNIST_dataset = MNIST(root="./data", train=True, download=True, transform=None)
MNIST_tensor = transforms.ToTensor()(MNIST_dataset[0][0])
MNIST_tensor_nom = transforms.Normalize(10, 10)(MNIST_tensor)
print(MNIST_tensor_nom)
值得注意的是,这里的(10, 10)并不是真正的标准差和均值,是我随意写的,主要是用来验证使用方法而已。
torchvision.transforms.Compose(transforms)
这个就更简单了,将多个transform
组合起来使用。
例如:
transforms.Compose(
[torchvision.transforms.ToTensor(), #先转化为Tensor
torchvision.transforms.Normalize(mean,std)] #在进行正则化
)
可以看到传入的参数实际上是一个由多个transforms包下的处理类组成的List。
知道了这样,我们就可以改写一下上边所进行两个操作:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
transforms_fn = Compose([ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))])
MNIST_dataset = MNIST(root="./data", train=True, download=True, transform=transforms_fn)
print(MNIST_dataset)
你如果看过上一篇关于数据集类的介绍,就知道Dataset类都是用来初始化数据集的,并不能对数据集做分批,打乱,等操作,所以对一个完整的数据集来说应该配合有对应的Dataloader类来进一步对数据集操作,并将数据送入网络。
结合Dataloader进行操作:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
transforms_fn = Compose([ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))])
MNIST_dataset = MNIST(root="./data", train=True, download=True, transform=transforms_fn)
#构建dataloader
MNIST_dataloader = DataLoader(dataset=MNIST_dataset, batch_size=2, drop_last=True, shuffle=True)
这样基本的一个数据集就构建完成了,当然我们可以给他增加一些功能并整合成一个函数,可以通过MNIST_dataset的train参数来设置是训练集还是测试集,只要传入一个参数即可,最后让他返回制作好的Dataloader。
def MNIST_create_dataloader(train=True):
transforms_fn = Compose([ToTensor(), Normalize(mean=0.1307, std=0.3081)])
MNIST_dataset = MNIST(root="./data", train=train, download=True, transform=transforms_fn)
MNIST_dataloader = DataLoader(dataset=MNIST_dataset, batch_size=2, drop_last=True, shuffle=True)
return MNIST_dataloader
我们来遍历一下制作好的Dataloader,完整代码如下:
import sys
import torch
import torch.nn as nn
import torch.optim as opt
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
def MNIST_create_dataloader(train=True):
transforms_fn = Compose([ToTensor(), Normalize(mean=0.1307, std=0.3081)])
MNIST_dataset = MNIST(root="./data", train=train, download=True, transform=transforms_fn)
MNIST_dataloader = DataLoader(dataset=MNIST_dataset, batch_size=2, drop_last=True, shuffle=True)
return MNIST_dataloader
def get_MNIST_dataloader():
MNIST_dataloader = MNIST_create_dataloader(False)
for num, i in enumerate(MNIST_dataloader):
print(num, i)
get_MNIST_dataloader()
只看最后一batch的输出结果:
因为batch_size=2,每组就是由两张照片所组成的。
29999 [tensor([[[[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
...,
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242]]],
[[[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
...,
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, ..., -0.4242, -0.4242, -0.4242]]]]), tensor([5, 8])]