0
点赞
收藏
分享

微信扫一扫

(36)RuntimeError: Given groups=4, weight of size [4, 1, 11, 11], expected input xxxxxxxxx

问题描述:

在测试AODNet去雾网络时,计算测试集的指标ssim和psnr,从test_loader中读取测试集的清晰图像和去雾之后的图像,作为ssim的输入进行计算,原文代码如下:

for iter_test, (img_orig, img_dehaze) in enumerate(test_loader):
	ssim_test = ssim(img_orig, img_dehaze)

其中,test_loader的部分代码如下:

	def __getitem__(self, index):

		data_orig_path, data_dehaze_path = self.test_list[index]

		data_orig = Image.open(data_orig_path)
		data_dehaze = Image.open(data_dehaze_path)

		data_orig = data_orig
		data_dehaze = data_dehaze

xxxxxxxxxxxxx

ssim方法的定义如下:

def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx

报错

Traceback (most recent call last):
  File "/home/PycharmProjects/z_fog/AOD-Net/dehaze.py", line 93, in <module>
    ssim_test = ssim(img_orig, img_dehaze)
  File "/home/PycharmProjects/z_fog/AOD-Net/metrics.py", line 64, in ssim
    return _ssim(img1, img2, window, window_size, channel, size_average)
  File "/home/PycharmProjects/z_fog/AOD-Net/metrics.py", line 38, in _ssim
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
RuntimeError: Given groups=4, weight of size [4, 1, 11, 11], expected input[1, 3, 640, 480] to have 4 channels, but got 3 channels instead

问题分析:报错的意思是输入的图像通道数与期望的通道数不符

问题解决:在test_loader中加载测试集的位置,增加对图像的转换,对应的代码改为:

data_orig = data_orig.convert('RGB')
		data_dehaze = data_dehaze.convert('RGB')

举报

相关推荐

0 条评论