0
点赞
收藏
分享

微信扫一扫

Torch_3_Dataset与Dataloader

Python芸芸 2022-05-01 阅读 58

介绍


  • 看代码的过程中不难发现,不同作者模型训练时的数据输入方法差别非常大。
  • torch提供了统一的接口,通过迭代器实现数据和标签的读取,使用方便也利于阅读。


实现方法


  • 导入

    from torch.utils.data import Dataset, DataLoader
    
  • Dataset

    • torch内置抽象类,无法实例化,通过继承并重写魔术方法实现
    class MyDataset(Dataset):
        def __init__(self, filepath):
            xy = np.load(filepath)
            self.len = xy.shape[0]
            self.x_data = torch.from_numpy(xy[:, :-1])
            self.y_data = torch.from_numpy(xy[:, [-1]])
    
        def __getitem__(self, item):
            return self.x_data[item], self.y_data[item]
    
        def __len__(self):
            return self.len
    
    dataset = MyDataset('MyData.npy')
    
    • 示例中,以读取numpy文件为例,通过重写__getitem____len__方法,实现数据的随机读取

  • Dataloader

    • 调用dataset 实例,通过设定的参数可生成DataLoader
    train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
    

  • 训练中调用数据

    for i, data in enumerate(train_loader, 0):  #
    		x, y = data
    
举报

相关推荐

0 条评论