图像修复-SwinIR: Image Restoration Using Swin Transformer
SwinIR是一个专门用于图像修复任务的强基线模型,它基于Swin Transformer架构。相比于基于卷积神经网络的传统方法,SwinIR将Transformer在高层视觉任务中的优异表现引入到图像修复这类低层视觉任务中。
文章目录
SwinIR架构图
浅层特征提取
架构代码
# Shallow feature extraction: a single 3x3 conv (stride 1, padding 1) that maps
# the input channels to embed_dim. Spatial size is preserved — this does NOT downsample.
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
深层特征提取
架构代码
# Deep feature extraction: a stack of Residual Swin Transformer Blocks (RSTB).
self.num_layers = len(depths)  # number of RSTB layers, one per entry in `depths`
self.embed_dim = embed_dim  # embedding (channel) dimension
self.ape = ape  # whether to use an absolute position embedding
self.patch_norm = patch_norm  # whether to normalize patch embeddings
self.num_features = embed_dim  # feature channel count equals the embedding dim
self.mlp_ratio = mlp_ratio  # hidden/embed ratio of the MLP in each transformer block

# Split the feature map into non-overlapping patches (token sequence).
self.patch_embed = PatchEmbed(
    img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches  # total number of patches
patches_resolution = self.patch_embed.patches_resolution  # patch-grid resolution
self.patches_resolution = patches_resolution

# Merge the non-overlapping patches back into an image-shaped feature map.
self.patch_unembed = PatchUnEmbed(
    img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)

# Optional absolute position embedding, one vector per patch.
if self.ape:
    self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
    trunc_normal_(self.absolute_pos_embed, std=.02)  # truncated-normal init
self.pos_drop = nn.Dropout(p=drop_rate)  # dropout applied after position embedding

# Stochastic depth: per-block drop-path rates increase linearly from 0 up to
# drop_path_rate over all blocks (an increasing, not decreasing, schedule).
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

# Build the Residual Swin Transformer Blocks (RSTB).
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
    layer = RSTB(dim=embed_dim,
                 input_resolution=(patches_resolution[0], patches_resolution[1]),
                 depth=depths[i_layer],  # number of Swin blocks inside this RSTB
                 num_heads=num_heads[i_layer],  # attention heads for this layer
                 window_size=window_size,  # window size for window attention
                 mlp_ratio=self.mlp_ratio,
                 qkv_bias=qkv_bias, qk_scale=qk_scale,  # QKV projection options
                 drop=drop_rate, attn_drop=attn_drop_rate,  # dropout rates
                 # slice of the global drop-path schedule for this layer's blocks
                 drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                 norm_layer=norm_layer,
                 downsample=None,  # resolution is kept constant across stages
                 use_checkpoint=use_checkpoint,  # gradient checkpointing to save memory
                 img_size=img_size,
                 patch_size=patch_size,
                 resi_connection=resi_connection  # residual conv type: '1conv' or '3conv'
                 )
    self.layers.append(layer)
# A single final normalization applied after ALL RSTBs (not one per layer).
self.norm = norm_layer(self.num_features)

# Last conv of the deep feature extraction stage.
if resi_connection == '1conv':
    self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)  # single 3x3 conv
elif resi_connection == '3conv':
    # Three-conv bottleneck to save parameters and memory.
    self.conv_after_body = nn.Sequential(
        nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),  # 3x3 conv, reduce channels 4x
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),  # 1x1 conv, channels unchanged
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)  # 3x3 conv, restore channels
    )
高质量图像重建
架构代码
# High-quality image reconstruction head; the variant is selected by `upsampler`.
if self.upsampler == 'pixelshuffle':
    # Classical super-resolution: conv + LeakyReLU before a PixelShuffle upsampler.
    self.conv_before_upsample = nn.Sequential(
        nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
        nn.LeakyReLU(inplace=True)
    )
    self.upsample = Upsample(upscale, num_feat)  # PixelShuffle-based upsampler
    self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)  # final conv to output channels
elif self.upsampler == 'pixelshuffledirect':
    # Lightweight SR: a single one-step PixelShuffle upsampler to cut parameters.
    self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                    (patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
    # Real-world SR: per the name, nearest-neighbor upsampling interleaved with
    # convs to reduce artifacts (upsampling itself happens in forward, not shown here).
    self.conv_before_upsample = nn.Sequential(
        nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
        nn.LeakyReLU(inplace=True)
    )
    # NOTE(review): presumably applied after the first nearest upsample in
    # forward — confirm against the model's forward pass.
    self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
    if self.upscale == 4:
        # 4x upscaling needs a second upsample/conv stage.
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
    self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)  # conv at high resolution
    self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)  # final output conv
    self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)  # shared activation
else:
    # Denoising / JPEG artifact reduction: no upsampling, just a final 3x3 conv.
    self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)