dir() 函数: 让我们知道工具箱以及工具箱中的分隔区有什么东西
help()函数: 让我们知道每个工具箱如何使用,工具的使用方法
Pytorch加载数据
先来一个小案例:
根据输入得到对应的labels
将数据集dataset放入项目的文件目录下。
在pycharm中选中图片,CTRL+SHIFT+C
复制文件所在得绝对路径,复制之后要手动在”\“之后加上”\“表示转义,在控制台中打开图片:img = Image.open(img_path)
控制台右边可以直接查看对象的属性,也可以用命令查看对象的相关属性
可以使用img.show() 将图片显示出来
复制相对地址 Copy Path --> Path From Content Root
from torch.utils.data import Dataset
from PIL import Image # 用于读取图片
import os # 用于获取一些文件信息,如地址信息
"""
创建MyData类,继承自Dataset
"""
class MyData(Dataset):
"""
根据MyData创建实例时,会自动运行 __init__(self) 函数
一般用于为整个class类提供全局变量
root_dir: 根目录(相对来说)
label_dir: 图片上一级目录的名称
"""
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 应该是将传入的参数设为全局变量
self.label_dir = label_dir # 应该是将传入的参数设为全局变量
self.path = os.path.join(root_dir, label_dir) # 将路径拼接起来,得到文件所在路径
self.img_path = os.listdir(self.path) # 获取所有图片的列表
"""
数据集本质应当是所有数据样本的一个列表,因此每个样本都有对应的索引index。
我们取用一个样本最简单的方式就是用该样本的index从数据列表中把它取出来。__getitem__就是做这样一件事。
具体来说,Dataset可以简单想成一个列表: datalist = [sample_0, sample_1, ..., sample_n - 1],
__getitem__做的事情就是返回第index个样本的具体数据: return datalist[index]。
"""
def __getitem__(self, index):
img_name = self.img_path[index] # 变量 img_name 是图片列表的中一个图片的名称(懂意思就行)
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 文件的相对路径
img = Image.open(img_item_path) # 打开文件,打开 != 显示
label = self.label_dir # label 就是文件所在目录
return img, label # 返回img和label
"""
数据集有多长
"""
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = MyData(root_dir, ants_label_dir) # 蚂蚁数据集
bees_dataset = MyData(root_dir, bees_label_dir) # 蜜蜂数据集
train_dataset = ants_dataset + bees_dataset # 将两个数据集合并
img1, label = train_dataset[123]
img1.show()
img2, label2 = train_dataset[124]
img2.show()