The problem
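When implementing simple image classification or image segmentation in PyTorch, we usually rely on the DataLoader's default batching, and it rarely causes trouble. In object detection or instance segmentation, however, different images often contain different numbers of boxes. The default batching tries to stack each image's targets into a single tensor. Because their sizes differ, this fails with an error like the following: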
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 44 in dimension 1
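The failure can be reproduced with a toy dataset. The sketch below is a minimal illustration under stated assumptions: `ToyDetectionDataset` and `try_default_collate` are hypothetical names, and every image has a fixed 3x32x32 shape while only the box count varies. Without a `collate_fn`, the DataLoader calls `torch.stack` on each field, and that call raises a `RuntimeError` when the `boxes` tensors differ in length. The exact wording of the message depends on the PyTorch version.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDetectionDataset(Dataset):
    """Hypothetical dataset: every image has a different number of boxes."""

    def __init__(self, box_counts):
        self.box_counts = box_counts

    def __len__(self):
        return len(self.box_counts)

    def __getitem__(self, idx):
        image = torch.rand(3, 32, 32)  # same shape for every sample
        n = self.box_counts[idx]
        target = {
            "boxes": torch.rand(n, 4),  # first dimension varies with n
            "labels": torch.ones(n, dtype=torch.int64),
        }
        return image, target


def try_default_collate(box_counts):
    """Return the error message if default batching fails, else None."""
    # No collate_fn given, so the DataLoader uses its default collation
    loader = DataLoader(ToyDetectionDataset(box_counts), batch_size=len(box_counts))
    try:
        next(iter(loader))
    except RuntimeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(try_default_collate([3, 3]))   # None: equal box counts stack fine
    print(try_default_collate([3, 44]))  # error message raised by torch.stack
```

Batching only succeeds when every image in the batch happens to have the same number of boxes, which is rarely true for detection data.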
The solution
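Read each batch as tuples
Pass a custom collate_fn to the DataLoader so that the images and targets of a batch are grouped into tuples instead of being stacked into single tensors: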
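import torch

def collate_fn(batch):
    # batch is a list of (image, target) pairs; zip(*batch) regroups it
    # into (images, targets) tuples instead of stacking into one tensor
    return tuple(zip(*batch))

train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)

batch_step = 0
for images, targets in train_data_loader:
    batch_step += 1
    # move each image and each target dict to the device one by one
    images = list(image.to(device) for image in images)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]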
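To check that the tuple-based collate_fn handles varying box counts, here is a minimal self-contained sketch. It reuses the same `collate_fn` as above. `ToyDetectionDataset` and `first_batch_box_counts` are hypothetical helpers, and the device is assumed to be the CPU so the example runs anywhere.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def collate_fn(batch):
    # Regroup [(img, tgt), ...] into ((img, ...), (tgt, ...)) without stacking
    return tuple(zip(*batch))


class ToyDetectionDataset(Dataset):
    """Hypothetical dataset whose images carry different numbers of boxes."""

    def __init__(self, box_counts):
        self.box_counts = box_counts

    def __len__(self):
        return len(self.box_counts)

    def __getitem__(self, idx):
        n = self.box_counts[idx]
        image = torch.rand(3, 32, 32)
        target = {"boxes": torch.rand(n, 4), "labels": torch.ones(n, dtype=torch.int64)}
        return image, target


def first_batch_box_counts(box_counts, batch_size):
    """Return (number of images, boxes per image) for the first batch."""
    loader = DataLoader(ToyDetectionDataset(box_counts), batch_size=batch_size,
                        shuffle=False, collate_fn=collate_fn)
    images, targets = next(iter(loader))
    device = torch.device("cpu")  # assumption: CPU; swap in "cuda" if available
    images = [image.to(device) for image in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    return len(images), [t["boxes"].shape[0] for t in targets]


if __name__ == "__main__":
    print(first_batch_box_counts([3, 44, 7], batch_size=3))  # (3, [3, 44, 7])
```

Each target keeps its own box count, so nothing needs to be padded. The resulting list of image tensors and list of target dicts (with `boxes` and `labels` keys) matches the input format that torchvision's detection models such as Faster R-CNN expect during training.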