4. torchvision Datasets

张宏涛心理 2022-03-12 69 views
pytorch

1.  Basic functions and how to use them

import torchvision

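# root: where the dataset is stored; train: True for the training set, False for the test set;
# download: whether to download the data if it is not already under root
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)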

print(test_set[0])
# (<PIL.Image.Image image mode=RGB size=32x32 at 0x7F52D13E00D0>, 3)

print(test_set.classes)
# ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img, target = test_set[0]
print(test_set.classes[target])
# cat
img.show()
# Opens img in the system's default image viewer

2. Using a dataset with transforms

Use a transform to convert each image to a tensor, then display the images in TensorBoard.

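import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# Build a transform pipeline that converts PIL images to tensors
tensor_compose = transforms.Compose([transforms.ToTensor()])
# Event files are written to the "logs" directory
writer = SummaryWriter("logs")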

# root: where the dataset is stored; train: True for the training set, False for the test set;
# download: whether to download the data; transform: applied to each image (here PIL image -> tensor)
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=tensor_compose, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=tensor_compose, download=True)

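# Log the first 10 training images under the tag "dataset";
# the global step i selects each image on the TensorBoard slider.
for i in range(10):
    img, target = train_set[i]
    writer.add_image("dataset", img, i)

# Flush pending events and close the writer
writer.close()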
