0
点赞
收藏
分享

微信扫一扫

Pytorch 加载数据集的几种方法

最后的执着 2022-04-26 阅读 48

Pytorch 加载数据集的几种方法

总结

方案1:

 

 

方案2:

train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

在for循环中调用

for i, (images, labels) in enumerate(train_loader):
举报

相关推荐

0 条评论