0
点赞
收藏
分享

微信扫一扫

Pytorch 加载自有数据集

Pytorch 加载自有数据集

抽象基类 data.Dataset

class Dataset(Generic[T_co]):
	r"""表示数据集的抽象类
	所有数据集都是该类的子类。子类必须重写 `__getitem__` 方法,
	实现通过 key 获取数据样本;子类也可以重写 `__len___` 方法,
	来获取数据集的尺寸。
	"""

继承基类自定义数据集

# 自定义数据集类
class MyDataset(torch.utils.data.Dataset):
    # 
    def __init__(self, *args):
        super().__init__()
        # 初始化数据集包含的数据和标签
        pass
        
    def __getitem__(self, index):
        # 根据索引index从文件中读取一个数据
        # 对数据预处理
        # 返回数据即对应标签
        pass
    
    def __len__(self):
        # 返回数据集的大小
        return len()

示例
本地图片保存在文件夹 “~/train_data” 中:
0.png
1.png
2.png
3.png
4.png
5.png

标签存放在对应的 txt 文件中:
0.png 0
1.png 0
2.png 2
3.png 1
4.png 3
5.png 2

创建自定义数据集类:

class MyData(torch.utils.data.Dataset):
    def __init__(self, data_path, label_path):
        self.data_path = data_path
        self.label_path = label_path
        self.imgs = []
    
        with open(labels_path, 'r') as fp:
            for line in fp:
                line = line.strip()
                sample = line.split(' ')
                self.imgs.append((sample[0], sample[1]))
                
    def __len__():
        return len(self.imgs)
    
    def __getitem__(self, index):
        picture, label = self.imgs[index]
        picture = transforms.ToTensor()(Image.open(self.data_path+'/picture'))
        return picture, label

举报

相关推荐

0 条评论