Pytorch 加载自有数据集
抽象基类 data.Dataset
。
class Dataset(Generic[T_co]):
r"""表示数据集的抽象类
所有数据集都是该类的子类。子类必须重写 `__getitem__` 方法,
实现通过 key 获取数据样本;子类也可以重写 `__len___` 方法,
来获取数据集的尺寸。
"""
继承基类自定义数据集
# 自定义数据集类
class MyDataset(torch.utils.data.Dataset):
#
def __init__(self, *args):
super().__init__()
# 初始化数据集包含的数据和标签
pass
def __getitem__(self, index):
# 根据索引index从文件中读取一个数据
# 对数据预处理
# 返回数据即对应标签
pass
def __len__(self):
# 返回数据集的大小
return len()
示例:
本地图片保存在文件夹 “~/train_data” 中:
0.png
1.png
2.png
3.png
4.png
5.png
…
标签存放在对应的 txt 文件中:
0.png 0
1.png 0
2.png 2
3.png 1
4.png 3
5.png 2
…
创建自定义数据集类:
class MyData(torch.utils.data.Dataset):
def __init__(self, data_path, label_path):
self.data_path = data_path
self.label_path = label_path
self.imgs = []
with open(labels_path, 'r') as fp:
for line in fp:
line = line.strip()
sample = line.split(' ')
self.imgs.append((sample[0], sample[1]))
def __len__():
return len(self.imgs)
def __getitem__(self, index):
picture, label = self.imgs[index]
picture = transforms.ToTensor()(Image.open(self.data_path+'/picture'))
return picture, label