Pytorch: 数据库介绍、数据操作和预处理
Copyright: Jingmin Wei, Pattern Recognition and Intelligent System, School of Artificial and Intelligence, Huazhong University of Science and Technology
文章目录
本教程不商用,仅供学习和参考交流使用,如需转载,请联系本人。
常用数据库简介
ImageNet 数据集,有 1400 多万张图片,2 万多个类别,超过百万的有明确的标注。与其对应的是 ILSVRC ,国际计算机视觉挑战赛。很多 torch 上的预训练模型都是基于 ImageNet 训练得到的。
PASCAL VOC 数据集,Pattern Analysis, Statistical Modelling and Computational Learning. 常用的有 VOC 2007 和 VOC 2012 。VOC 数据集结构如下:
-
JPEGImages: 包含所有训练测试的图片。
-
Annotations: 存放 XML 格式的标签数据,每一个 XML 文件都对应 JPEGImages 的一张图片。
-
ImageSets: 对于物体检测只需要 Main 子文件夹,并在 Main 文件夹中建立 Trainval.txt, train.txt, val.txt 以及 test.txt,在各文件夹记录对应图片名即可。
COCO 数据集,Common Objects in Contexts,针对物体检测、分割、图像语义理解和人体关节点监测等,其数据集难度更大,小物体更多,物体大小的跨度也更大。COCO 也提供了多种语言的 API。
自动驾驶的数据集,比如说 KITTI, Cityscape, Udacity 等。
导入本节需要库:
import torch
import torch.nn as nn
import torch.utils.data as Data
from torchvision import transforms
from sklearn.datasets import load_boston, load_iris
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
数据加载
数据加载的过程主要分为 3 步:
继承 Dataset 类 -> 增加数据变换 -> 继承 Dataloader
继承 Dataset 类
torch.utils.data.Dataset 类的继承和改写:
class my_dataset(Data.Dataset):
def __init__(self, image_path, annotation_path, transform=None):
super(my_data, self).__init__()
# 初始化,读取数据集
def __len__(self):
# 获取数据集大小
pass
def __getitem__(self, id):
# 对于指定的 id, 读取数据并返回
pass
# 对于上述实例化
my_image_path = ...
my_annotation_path = ...
dataset = my_dataset(my_image_path, my_annotation_path)
for data in dataset:
print(data)
pass
图像数据增强
torchvision.transforms 可以方便地进行图像缩放、裁剪、随机翻转、填充和张量的归一化。操作对象是 PIL 的 Image 或者 Tensor。
如果需要多个变换,可以用 transforms.Compose 将多个变换组合。
dataset = my_dataset(my_image_path, my_annotation_path, transforms=
transforms.Compose([
# 重设大小为256*256
transforms.Resize(256),
# 以 0.5 的概率进行随机翻转
transforms.RandomHorizontalFlip(),
# 转为 Tensor, 从 [0, 255] 归一化到 [0, 1]
transforms.ToTensor(),
# 进行 mean 和 std 为 0.5 的标准化
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
]))
继承 Dataloader
为了进行批量处理,和随机选取,还需要 torch.utils.data.Dataloader。
dataloader = Data.DataLoader(dataset = dataset, # 使用的数据
batch_size = 64, # 批量处理的样本大小
shuffle = True, # 每次都随机打乱数据
num_workers = 1, # 使用两个进程
)
for step, (b_x, b_y) in enumerate(dataloader):
# 将 b_x 用于训练网络,b_y 用于计算损失
pass
或者生成一个迭代器:
data_iter = iter(dataloader)
for step in range(iter_per_epoch):
data = next(data_iter)
# 将 data 用于训练网络
pass
数据预处理
在torch.util.data模块中包含一些常用的数据预处理操作
torch.util.data.TensorDataset() # 将数据处理为张量
torch.util.data.ConcatDataset() # 处理多个数据集
torch.util.data.Subset() # 根据索引获取数据集的子集
torch.util.data.DataLoader() # 数据加载器
torch.util.data.random_split() # 随机将数据集拆分为给定长度的非重叠新数据集
接下来我们用例子演示数据加载的过程:
高维数组
每个样本都有很多个预测变量(特征)和一个被预测变量(标签)
连续数据对应回归预测,离散数据对应分类问题。
使用sklearn提供的boston和iris数据分别做回归和分类的框架表示。
回归数据准备(boston房价回归)
# 加载sklearn中的boston回归数据集
boston_X, boston_y = load_boston(return_X_y = True) # data和label
print(boston_X.dtype)
print(boston_y.dtype)
float64
float64
boston_X
array([[6.3200e-03, 1.8000e+01, 2.3100e+00, ..., 1.5300e+01, 3.9690e+02,
4.9800e+00],
[2.7310e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9690e+02,
9.1400e+00],
[2.7290e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9283e+02,
4.0300e+00],
...,
[6.0760e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
5.6400e+00],
[1.0959e-01, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9345e+02,
6.4800e+00],
[4.7410e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
7.8800e+00]])
boston_y
array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21.5, 19.6, 15.3, 19.4,
17. , 15.6, 13.1, 41.3, 24.3, 23.3, 27. , 50. , 50. , 50. , 22.7,
25. , 50. , 23.8, 23.8, 22.3, 17.4, 19.1, 23.1, 23.6, 22.6, 29.4,
23.2, 24.6, 29.9, 37.2, 39.8, 36.2, 37.9, 32.5, 26.4, 29.6, 50. ,
32. , 29.8, 34.9, 37. , 30.5, 36.4, 31.1, 29.1, 50. , 33.3, 30.3,
34.6, 34.9, 32.9, 24.1, 42.3, 48.5, 50. , 22.6, 24.4, 22.5, 24.4,
20. , 21.7, 19.3, 22.4, 28.1, 23.7, 25. , 23.3, 28.7, 21.5, 23. ,
26.7, 21.7, 27.5, 30.1, 44.8, 50. , 37.6, 31.6, 46.7, 31.5, 24.3,
31.7, 41.7, 48.3, 29. , 24. , 25.1, 31.5, 23.7, 23.3, 22. , 20.1,
22.2, 23.7, 17.6, 18.5, 24.3, 20.5, 24.5, 26.2, 24.4, 24.8, 29.6,
42.8, 21.9, 20.9, 44. , 50. , 36. , 30.1, 33.8, 43.1, 48.8, 31. ,
36.5, 22.8, 30.7, 50. , 43.5, 20.7, 21.1, 25.2, 24.4, 35.2, 32.4,
32. , 33.2, 33.1, 29.1, 35.1, 45.4, 35.4, 46. , 50. , 32.2, 22. ,
20.1, 23.2, 22.3, 24.8, 28.5, 37.3, 27.9, 23.9, 21.7, 28.6, 27.1,
20.3, 22.5, 29. , 24.8, 22. , 26.4, 33.1, 36.1, 28.4, 33.4, 28.2,
22.8, 20.3, 16.1, 22.1, 19.4, 21.6, 23.8, 16.2, 17.8, 19.8, 23.1,
21. , 23.8, 23.1, 20.4, 18.5, 25. , 24.6, 23. , 22.2, 19.3, 22.6,
19.8, 17.1, 19.4, 22.2, 20.7, 21.1, 19.5, 18.5, 20.6, 19. , 18.7,
32.7, 16.5, 23.9, 31.2, 17.5, 17.2, 23.1, 24.5, 26.6, 22.9, 24.1,
18.6, 30.1, 18.2, 20.6, 17.8, 21.7, 22.7, 22.6, 25. , 19.9, 20.8,
16.8, 21.9, 27.5, 21.9, 23.1, 50. , 50. , 50. , 50. , 50. , 13.8,
13.8, 15. , 13.9, 13.3, 13.1, 10.2, 10.4, 10.9, 11.3, 12.3, 8.8,
7.2, 10.5, 7.4, 10.2, 11.5, 15.1, 23.2, 9.7, 13.8, 12.7, 13.1,
12.5, 8.5, 5. , 6.3, 5.6, 7.2, 12.1, 8.3, 8.5, 5. , 11.9,
27.9, 17.2, 27.5, 15. , 17.2, 17.9, 16.3, 7. , 7.2, 7.5, 10.4,
8.8, 8.4, 16.7, 14.2, 20.8, 13.4, 11.7, 8.3, 10.2, 10.9, 11. ,
9.5, 14.5, 14.1, 16.1, 14.3, 11.7, 13.4, 9.6, 8.7, 8.4, 12.8,
10.5, 17.1, 18.4, 15.4, 10.8, 11.8, 14.9, 12.6, 14.1, 13. , 13.4,
15.2, 16.1, 17.8, 14.9, 14.1, 12.7, 13.5, 14.9, 20. , 16.4, 17.7,
19.5, 20.2, 21.4, 19.9, 19. , 19.1, 19.1, 20.1, 19.9, 19.6, 23.2,
29.8, 13.8, 13.3, 16.7, 12. , 14.6, 21.4, 23. , 23.7, 25. , 21.8,
20.6, 21.2, 19.1, 20.6, 15.2, 7. , 8.1, 13.6, 20.1, 21.8, 24.5,
23.1, 19.7, 18.3, 21.2, 17.5, 16.8, 22.4, 20.6, 23.9, 22. , 11.9])
# 特征和标签转为张量
train_X = torch.from_numpy(boston_X.astype(np.float32))
train_y = torch.from_numpy(boston_y.astype(np.float32))
print(train_x.dtype)
print(train_y.dtype)
torch.float32
torch.float32
# 使用TensorDataset将X和y整理到一起
train_data = Data.TensorDataset(train_X, train_y)
# 定义数据加载器,将数据集批量切分为多个batch
train_loader = Data.DataLoader(dataset = train_data, # 使用的数据
batch_size = 64, # 批量处理的样本大小
shuffle = True, # 每次都随机打乱数据
num_workers = 1, # 使用两个进程
)
for step, (b_x, b_y) in enumerate(train_loader):
# 检查训练集一个batch样本的维度是否正确
if step > 0:
break
# 输出训练图像的尺寸和标签的尺寸及数据类型
print(b_x.shape)
print(b_y.shape)
print(b_x.dtype)
print(b_y.dtype)
torch.Size([64, 13])
torch.Size([64])
torch.float32
torch.float32
分类数据准备(iris鸢尾花分类)
# 加载sklearn中的iris分类数据集
iris_X, iris_y = load_iris(return_X_y = True) # data和label
print(iris_X.dtype)
print(iris_y.dtype)
float64
int32
iris_X
array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.6, 1.4, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]])
iris_y
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
# 特征和标签转为张量
train_X = torch.from_numpy(iris_X.astype(np.float32))
train_y = torch.from_numpy(iris_y.astype(np.int64))
print(train_X.dtype)
print(train_y.dtype)
torch.float32
torch.int64
# 使用TensorDataset将X和y整理到一起
train_data = Data.TensorDataset(train_X, train_y)
# 定义数据加载器,将数据集批量切分为多个batch
train_loader = Data.DataLoader(dataset = train_data, # 使用的数据
batch_size = 10, # 批量处理的样本大小
shuffle = True, # 每次都随机打乱数据
num_workers = 1, # 使用两个进程
)
for step, (b_x, b_y) in enumerate(train_loader):
# 检查训练集一个batch样本的维度是否正确
if step > 0:
break
# 输出训练图像的尺寸和标签的尺寸及数据类型
print(b_x.shape)
print(b_y.shape)
print(b_x.dtype)
print(b_y.dtype)
torch.Size([10, 4])
torch.Size([10])
torch.float32
torch.int64
图像数据
torchversion的dataset包含很多常用的分类数据集
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
从dataset模块中导入数据并预处理
以FashionMNIST为例,其包含一个60000张的28*28的灰度图作为训练集,10000张28*28的灰度图作为测试集。数据共十类,是服饰类的图像
# 使用FashionMNIST数据,准备训练集
train_data = FashionMNIST(root = './data/FashionMNIST', # 数据路径
train = True, # 只使用训练数据集
transform = transforms.ToTensor(),
# 把取值范围0-255的[H, W, C]图像数据转为0-1的[C, H, W]张量
download = False
# 下载成功后,若重新运行该代码块,就改成False
)
# 定义一个数据加载器,将数据集批量切分为多个batch
train_loader = Data.DataLoader(dataset = train_data, # 使用的数据集
batch_size = 64, # 待处理的样本大小
shuffle = True, # 每次迭代都随机打乱数据
num_workers = 2 # 使用两个进程
)
# 计算train_loader的batch数
print(len(train_loader))
938
# 使用FashionMNIST数据,准备测试集
test_data = FashionMNIST(root = './data/FashionMNIST', # 数据路径
train = False, # 不使用训练数据集
download = False
# 下载成功后,若重新运行该代码块,就改成False
)
# 为数据添加一个维度,并且取值范围归一化
# 用test_data.data获取图像数据
test_data_X = test_data.data.type(torch.FloatTensor) / 255.0
test_data_X = torch.unsqueeze(test_data_X, dim=1) # 在维度1上插入一个新维度
# 用test_data.targets获取图像标签
test_data_y = test_data.targets
print(test_data_X.shape)
print(test_data_y.shape)
torch.Size([10000, 1, 28, 28])
torch.Size([10000])
从文件夹中导入数据并预处理
datasets模块中的ImageFolder()函数,能读取png格式的数据集
# 将多个变换成操作组合在一起
train_data_transforms = transforms.Compose([
transforms.RandomResizedCrop(244), # 随机长宽比裁剪为244*244
transforms.RandomHorizontalFlip(), # 依概率p=0.5水平翻转
transforms.ToTensor(), # 转为张量并归一化为[0-1]
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]) # 图像标准化
])
# 读取图像
data_dir = './data/imgdata'
train_data = ImageFolder(data_dir, transform = train_data_transforms)
train_data_loader = Data.DataLoader(train_data, batch_size=4, shuffle=True, num_workers=1)
print('数据集的label:', train_data.targets)
数据集的label: [0, 1, 2]
# 获得一个batch数据
for step, (b_x, b_y) in enumerate(train_data_loader):
# 检查训练集一个batch样本的维度是否正确
if step > 0:
break
# 输出训练图像的尺寸和标签的尺寸
print(b_x.shape)
print(b_y.shape)
print('图像的取值范围是:', b_x.min(), "~", b_x.max())
torch.Size([3, 3, 244, 244])
torch.Size([3])
图像的取值范围是: tensor(-2.0837) ~ tensor(2.6051)
可见,总共3张图,每张图像是244*244,图像取值范围是-2.08-2.60
文本数据
由于教材上并没有给出数据集,所以我翻墙去kaggle上下了常用的自然语言处理的数据集。
然后利用pandas对数据格式进行了转换。
数据下载地址: https://www.kaggle.com/c/movie-review-sentiment-analysis-kernels-only/data
from sklearn.model_selection import train_test_split
train = pd.read_csv('./data/txtdata/train.tsv', sep='\t')
test = pd.read_csv('./data/txtdata/test.tsv', sep='\t')
print(train.head(5))
PhraseId SentenceId Phrase \
0 1 1 A series of escapades demonstrating the adage ...
1 2 1 A series of escapades demonstrating the adage ...
2 3 1 A series
3 4 1 A
4 5 1 series
Sentiment
0 1
1 2
2 2
3 2
4 2
Sentiment说明:0 - negative,1 - somewhat negative,2 - neutral,3 - somewhat positive,4 - positive
将数据提取[‘Sentiment’,‘Phrase’]列,然后转为csv格式:
train, val = train_test_split(train, test_size=0.2)
print(len(train))
# 将tsv文件转为csv文件,不写入索引,写入Phrase语句列和Sentiment标签列,不写表头
train.to_csv('./data/txtdata/train.csv', index=False, columns=['Sentiment','Phrase']) # 训练集
val.to_csv('./data/txtdata/val.csv', index=False, columns=['Sentiment','Phrase']) # 验证集
test.to_csv('./data/txtdata/test.csv', index=False, columns=['Phrase']) # 测试集
99878
在将kaggle数据转为csv格式后,
我们导入torchtext库的相关函数进行预处理。
from torchtext import data
# 用lambda表达式定义文本切分方法,使用空格切分
mytokensize = lambda x: x.split()
# 定义文本转为张量的相关操作
TEXT = data.Field(sequential = True, # 表明输入的文本是字符
tokenize = mytokensize, # 使用自定义的分词方法
use_vocab = True, # 创建一个词汇表
batch_first = True, # batch优先的数据方式
fix_length = 200 # 每个句子固定长度为200
)
# 定义标签转为张量的相关操作
LABEL = data.Field(sequential = False, # 表明输入的文本是数字
use_vocab = False, # 不创建词汇表
pad_token = None, # 不进行填充
unk_token = None # 没有无法识别的字符
)
# 对数据集的每列进行处理定义
text_data_fields = [('Sentiment', LABEL), ('Phrase', TEXT)]
# 读取训练集和验证集
traindata, testdata = data.TabularDataset.splits(path='./data/txtdata/',
format='csv',
train='train.csv',
fields = text_data_fields,
test = 'val.csv',
skip_header=True)
len(traindata), len(testdata)
(99878, 24970)
成功读取训练集和验证集后,使用data.BucketIterator()将他们定义为数据加载器
# 使用训练集构建单词表,并不指定与训练好的词向量
# max_size表示单词表使用的最大单词数量,vectors用于指定单词的词向量
TEXT.build_vocab(traindata, max_size=1000, vectors=None)
# 将训练数据集定义为数据加载器,便于对模型进行优化
train_iter = data.BucketIterator(traindata, batch_size = 4)
test_iter = data.BucketIterator(testdata, batch_size = 4)
# 获得batch数据
for step, batch in enumerate(train_iter):
if step > 0:
break
# 获得数据的标签和尺寸
print('数据的类别', batch.Sentiment)
print('数据的内容', batch.Phrase)
print('数据的尺寸', batch.Phrase.shape)
数据的类别 tensor([1, 3, 1, 1])
数据的内容 tensor([[ 97, 27, 275, 580, 12, 0, 0, 0, 55, 12, 104, 23, 0, 48,
279, 30, 0, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1],
[118, 13, 7, 2, 0, 264, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1],
[ 11, 302, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1],
[ 79, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1]])
数据的尺寸 torch.Size([4, 200])