0
点赞
收藏
分享

微信扫一扫

PytorchVision Transforms的基本使用

zhaoxj0217 2022-02-06 阅读 59

Transforms的基本概念

transforms是torchvision下的一个模块,主要帮助用户方便的对图像数据进行处理

它要求数据是(C, H, W)的三维数组,其中字母含义为:

  • C: Channel, 图片的通道,例如R、G、B
  • H,W, Height, Weight,图片的宽高

使用PIL读取一张图片

在使用Transforms前,先读取一张图片,用于后续使用

from PIL import Image

image = Image.open("images/mary.jpg")
image

在这里插入图片描述

Transforms的常用方法

Transforms的常用方法有如下

1.ToTensor(): 将一个PIL Image或一个numpy.ndarray转为Tensor

trans = transforms.ToTensor()
img_data = trans(image)
img_data.shape
torch.Size([3, 225, 225])

输出[3, 255, 255]表示有3个通道(R,G,B),每个通道有255x255个像素点

2.Normalize(mean, std, inplace=False): 将tensor归一化为均值为mean,方差为std的数据

# 将三个通道分别做归一化
# 第一个通道归一化为 均值为0,方差为1
# 第二个通道归一化为 均值为1,方差为2
# 第三个通道归一化为 均值为2,方差为3
img_data = transforms.Normalize(mean=(0, 1, 2), std=(1,2,3))(img_data)
img_data.shape
torch.Size([3, 225, 225])

Transforms的Compose方法

一张图片可能需要执行很多次Transforms方法,所以Transform提供了Compose方法,方便用户一次将其全部处理完毕

img_data = transforms.ToTensor()(image)
img_data = transforms.Normalize(mean=(0, 1, 2), std=(1,2,3))(img_data)

compose = transforms.Compose(
    [ # 将要对图片做的处理,全部一次性写全
          transforms.ToTensor(),
          transforms.Normalize(mean=(0, 1, 2), std=(1,2,3))
    ]
)
compose(image).equal(img_data)
True





参考资料

transforms官方文档:https://pytorch.org/vision/stable/transforms.html

举报

相关推荐

0 条评论