问题描述:
在测试AODNet去雾网络时,计算测试集的指标ssim和psnr,从test_loader中读取测试集的清晰图像和去雾之后的图像,作为ssim的输入进行计算,原文代码如下:
for iter_test, (img_orig, img_dehaze) in enumerate(test_loader):
ssim_test = ssim(img_orig, img_dehaze)
其中,test_loader的部分代码如下:
def __getitem__(self, index):
data_orig_path, data_dehaze_path = self.test_list[index]
data_orig = Image.open(data_orig_path)
data_dehaze = Image.open(data_dehaze_path)
data_orig = data_orig
data_dehaze = data_dehaze
xxxxxxxxxxxxx
ssim方法的定义如下:
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
报错:
Traceback (most recent call last):
File "/home/PycharmProjects/z_fog/AOD-Net/dehaze.py", line 93, in <module>
ssim_test = ssim(img_orig, img_dehaze)
File "/home/PycharmProjects/z_fog/AOD-Net/metrics.py", line 64, in ssim
return _ssim(img1, img2, window, window_size, channel, size_average)
File "/home/PycharmProjects/z_fog/AOD-Net/metrics.py", line 38, in _ssim
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
RuntimeError: Given groups=4, weight of size [4, 1, 11, 11], expected input[1, 3, 640, 480] to have 4 channels, but got 3 channels instead
问题分析:报错的意思是输入的图像通道数与期望的通道数不符
问题解决:在test_loader中加载测试集的位置,增加对图像的转换,对应的代码改为:
data_orig = data_orig.convert('RGB')
data_dehaze = data_dehaze.convert('RGB')