0
点赞
收藏
分享

微信扫一扫

【话题讨论】诺奖新纪元:机器学习与神经网络——跨界融合,塑造未来

图像修复-SwinIR: Image Restoration Using Swin Transformer

SwinIR是一个专门用于图像修复任务的基线模型,它基于Swin Transformer架构。相比于基于卷积神经网络的传统方法,SwinIR利用了Transformer在高层次视觉任务中的优异表现。

文章目录

SwinIR架构图

在这里插入图片描述

浅层特征提取
架构代码
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) # 就是简单的卷积,下采样
深层特征提取

在这里插入图片描述

架构代码
self.num_layers = len(depths)  # 设置网络层数,基于给定的 depths 列表长度
self.embed_dim = embed_dim  # 设置嵌入维度
self.ape = ape  # 是否使用绝对位置编码(absolute position embedding)
self.patch_norm = patch_norm  # 是否对patch进行归一化
self.num_features = embed_dim  # 特征通道数等于嵌入维度
self.mlp_ratio = mlp_ratio  # MLP比例系数

# 将图像划分为不重叠的patches
self.patch_embed = PatchEmbed(
    img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)  # 设置patch嵌入模块
num_patches = self.patch_embed.num_patches  # 获取patch数量
patches_resolution = self.patch_embed.patches_resolution  # 获取patches的分辨率
self.patches_resolution = patches_resolution  # 保存patches的分辨率

# 将不重叠的patches合并回图像
self.patch_unembed = PatchUnEmbed(
    img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
    norm_layer=norm_layer if self.patch_norm else None)  # 设置patch解嵌入模块

# 绝对位置编码
if self.ape:
    self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))  # 初始化绝对位置编码参数
    trunc_normal_(self.absolute_pos_embed, std=.02)  # 对位置编码进行截断正态分布初始化

    self.pos_drop = nn.Dropout(p=drop_rate)  # 位置编码的dropout层

    # 随机深度(Stochastic Depth)
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 随机深度的递减规则

    # 构建残差Swin Transformer块(RSTB)
    self.layers = nn.ModuleList()  # 保存网络层
    for i_layer in range(self.num_layers):
        # 初始化每层的RSTB模块
        layer = RSTB(dim=embed_dim,
                     input_resolution=(patches_resolution[0], patches_resolution[1]),  # 输入分辨率
                     depth=depths[i_layer],  # 当前层的深度
                     num_heads=num_heads[i_layer],  # 多头注意力机制的头数
                     window_size=window_size,  # 窗口大小
                     mlp_ratio=self.mlp_ratio,  # MLP的比例系数
                     qkv_bias=qkv_bias, qk_scale=qk_scale,  # QKV相关参数
                     drop=drop_rate, attn_drop=attn_drop_rate,  # dropout参数
                     drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # 随机深度的drop路径
                     norm_layer=norm_layer,  # 归一化层
                     downsample=None,  # 不进行下采样
                     use_checkpoint=use_checkpoint,  # 是否使用梯度检查点
                     img_size=img_size,  # 图像大小
                     patch_size=patch_size,  # patch大小
                     resi_connection=resi_connection  # 残差连接类型
                    )
        self.layers.append(layer)  # 将构建的层添加到层列表中
        self.norm = norm_layer(self.num_features)  # 为每层添加归一化操作

        # 构建深度特征提取中的最后一个卷积层
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)  # 单层卷积作为残差连接
        elif resi_connection == '3conv':
            # 为了节省参数和内存,使用三层卷积残差连接
            self.conv_after_body = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),  # 第一个卷积层,减少通道数
                nn.LeakyReLU(negative_slope=0.2, inplace=True),  # 激活函数LeakyReLU
                nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),  # 第二个1x1卷积层,保持通道数
                nn.LeakyReLU(negative_slope=0.2, inplace=True),  # 激活函数LeakyReLU
                nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)  # 第三个卷积层,恢复通道数
            )
高质量图像重建
架构代码
if self.upsampler == 'pixelshuffle':
    # 针对经典的超分辨率(SR)
    # 在上采样之前的卷积层,3x3卷积用于特征处理,LeakyReLU用于激活
    self.conv_before_upsample = nn.Sequential(
        nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
        nn.LeakyReLU(inplace=True)
    )
    # 使用PixelShuffle进行上采样
    self.upsample = Upsample(upscale, num_feat)
    # 最后的卷积层,用于生成输出图像,3x3卷积
    self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

elif self.upsampler == 'pixelshuffledirect':
    # 针对轻量级超分辨率(SR),为了减少参数量
    # 使用一步到位的PixelShuffle直接上采样
    self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                    (patches_resolution[0], patches_resolution[1]))

elif self.upsampler == 'nearest+conv':
    # 针对真实世界超分辨率(SR),减少伪影
    # 在上采样之前的卷积层,3x3卷积处理特征,LeakyReLU用于激活
    self.conv_before_upsample = nn.Sequential(
        nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
        nn.LeakyReLU(inplace=True)
    )
    # 第一个上采样卷积层,使用3x3卷积
    self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
    
    if self.upscale == 4:
        # 如果上采样比例为4,则需要第二个卷积层
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # 高分辨率特征的卷积层,3x3卷积
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # 最后的卷积层,生成最终输出图像,3x3卷积
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        # LeakyReLU激活函数
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
    # 针对图像去噪和JPEG压缩伪影减少
    # 最后的卷积层,3x3卷积,用于生成输出图像
    self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
损失函数
图像超分辨率(SR)

在这里插入图片描述

图像去噪和JPEG压缩伪影去除

在这里插入图片描述

举报

相关推荐

0 条评论