如果读者正在从事深度学习的项目,通常大部分时间都花在了处理数据上,而不是神经网络上。因为数据就像是网络的燃料:它越合适,结果就越快、越准确!神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。
参考:https://zhuanlan.zhihu.com/p/596730297
DataLoader加载和迭代数据集
Dataloader本质是一个迭代器对象,也就是可以通过for batch_idx
,batch_dict in dataloader
来提取数据集,提取的数量由batch_size 参数决定,得到这一batch的数据后,就可以喂入网络开始训练或者推理了。
在迭代的过程中,dataloader会自动调用dataset中的__getitem__ 函数,以获取一帧数据(item)
from torch.utils.data import DataLoader
DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=0,
collate_fn=None,
pin_memory=False,
)
以U-Net中的代码为例:
具体详见:U-Net代码复现
loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
1. 数据集
**dataset (Dataset) ** – dataset from which to load the data.
即自定义的数据集,非常重要,因为dataloader会调用dataset的一些重载函数(e.g. getitem && len )
2. 对数据进行批处理
batch_size (int, optional) – how many samples per batch to load(default: 1).
3. 在 CUDA 张量上加载数据
pin_memory(bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.
pin_memory参数直接将数据集加载为 CUDA 张量。它是一个可选参数,接受一个布尔值;如果设置为True,会在返回张量之前将张量复制到 CUDA 固定内存中。这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。
通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用pin_memory可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。
需要注意的是,使用pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory并不会带来明显的加速效果。
4.允许多进程
num_workers (int, optional) – how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)
这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为0,即在主进程中进行数据加载,而不使用额外的子进程。
5.合并数据集
collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
整合多个样本到一个batch时需要调用的函数,当 getitem 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足pytorch的输入要求。
以U-Net为例:
def __getitem__(self, idx):
name = self.ids[idx]
mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))
img_file = list(self.images_dir.glob(name + '.*'))
assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
mask = load_image(mask_file[0])
img = load_image(img_file[0])
assert img.size == mask.size, \
f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)
return {
'image': torch.as_tensor(img.copy()).float().contiguous(),
'mask': torch.as_tensor(mask.copy()).long().contiguous()
}
getitem 返回的是一个包含image和mask的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包(待补充。。。)
6.数据采样
sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shufflemust not be specified.