1. DataLoader
用于构建可迭代的数据加载器,训练时每一个iteration就是从DATa Loader中获取一个batch_size大小的数据。
参数
- dataset: Dataset类。是要自定义编写的,继承自
torch.utils.data.Dataset
的类。 - batchsize
- num_works
- shuffle
- drop_last:当样本数不是batchsize的整数倍时,是否舍弃最后一组数据
2. Dataset
所有自定义的Dataset都要继承于torch.utils.data.Dataset
,并且必须复写__getitem__()
方法。
ImageFolder
ImageFolder假设所有的文件按文件夹保存,每个文件夹下存放同一类数据,文件夹名为类名。
参数
- root:图片路径
- transform:对PIL Image进行的转换参数
- target_transform:对标签进行转换
- loader:如何读取图片,默认读取为RGB格式的 PIL Image对象
3. 构造数据集示例
- 文件夹格式
train_path = r'dataset/train'
- 预处理
train_transform = transforms.Compose([
transforms.Resize((64, 64)), # 数据格式转换
transforms.RandomCrop(40,padding=4), # 随机裁剪
transforms.ToTensor(), # 声明为张量,便于pytorch计算
transforms.Normalize([0.485,0.456,0.406],
[0.229,0.224,0.225],) # 对数据按通道进行标准化
])
- 自定义Dataset
class myData(Data.Dataset): # 继承自抽象类
def __init__(self, path, transform):
self.path = path
self.transform = transform
self.data_info = self.get_img_info(path) # 将path文件夹中的数据以元组的列表形式保存
self.label = []
for i in range(len(self.data_info)):
self.label.append(list(self.data_info[i])[1]) # 保存标签,与data_info对齐
def __getitem__(self, idx):
# 复写,根据idx返回 图像数据(转为张量), 标签, 索引
path_img = self.data_info[idx][0]
label = self.label[idx]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label, idx
def __len__(self):
return len(self.data_info)
@staticmethod
# 声明为静态方法, 即可以实例化调用也可以不实例化直接调用函数。
def get_img_info(data_dir):
'''
:return :返回一个列表,列表的每个元素是一个元组(图片地址, 图片标签)
'''
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = int(sub_dir)
data_info.append((path_img, int(label)))
return data_info
- 创建DataLoader
# 数据集
trainset = myData(
train_path,
train_transform
)
# 数据发生器
train_loader = Data.DataLoader(
dataset=trainset,
batch_size=4,
shuffle = True
)