0
点赞
收藏
分享

微信扫一扫

pytorch创建自定义数据集

Alex富贵 2022-02-18 阅读 79

pytorch为我们提供了Dataset类来提供所用数据集的创建任务。
数据集有两种情况:
1.pytorch中写好的数据集,如CIFAR10,我们在使用该数据集时只需要以下代码:data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
datasets.CIFAR10就是Datasets的一个子类,data是这个类的一个实例。
2.利用Dataset自定义数据集:
模板为

class MovingMNISTdataset(Dataset):#需要继承Dataset类
    ##dataset class for moving MNIST data
    ##Initialize
    def __init__(self, path):
        self.path = path
        self.data = MNISTdataLoader(path)

    def __len__(self):
        return len(self.data[:, 0, 0, 0])

    def __getitem__(self, indx):
        ##getitem method
        self.trainsample_ = self.data[indx, ...]
        self.sample_ = self.trainsample_/255.0

        self.sample = torch.from_numpy(np.expand_dims(self.sample_, axis = 1)).float()
        return self.sample

其中 getitem(self, index), len(self) 两个内建方法,用来表示从索引到样本的映射(Map).

举报

相关推荐

0 条评论