0
点赞
收藏
分享

微信扫一扫

[GAN]FT-TDR: Frequency-guided Transformer and Top-Down Refinement Network论文讲解

文章目录

FT-TDR: Frequency-guided Transformer and Top-Down Refinement Network for Blind Face Inpainting

  • FT-TDR(Frequency-guided Transformer and Top-Down Refinement Network) = Transformer-based Mask Detection + TDRefine Inpainting

Preprocess:Training Data Generation

  • Generate the masked image for training:

I m c o n = I g t ⨀ ( 1 − M b i n ) + M c o n ⨀ M b i n I_{m_{con}} = I_{gt}\bigodot(1-M_{bin}) + M_{con}\bigodot M_{bin} Imcon=Igt(1Mbin)+MconMbin

I g t I_{gt} Igt: uncorrupted ground truth face image

I m c o n I_{m_{con}} Imcon: masked image that model generated

M b i n M_{bin} Mbin: binary mask

{ M b i n = 0 , k n o w n   p i x e l s M b i n = 1 , c o r r p u t e d   p i x e l s \begin{cases} M_{bin} = 0, known\ pixels \\ M_{bin}=1, corrputed\ pixels \end{cases} {Mbin=0,known pixelsMbin=1,corrputed pixels

M c o n M_{con} Mcon: noisy visual signal

⨀ \bigodot : Hadamard product operator

  • Note: 训练过程中, M b i n M_{bin} Mbin M c o n M_{con} Mcon是随机选择的,两者之间没有关联。

Part 1:Tranformer-based Mask Detection Module

  • 利用 θ m \theta_m θm参数化建模 G M \mathcal{G}_{M} GM

M ^ b i n = G M ( I m c o n ; θ m ) \hat{M}_{b i n}=\mathcal{G}_{M}\left(I_{m_{c o n}} ; \theta_{m}\right) M^bin=GM(Imcon;θm)

M ^ b i n \hat{M}_{b i n} M^bin: predicted binary mask

  • 将生成的掩膜 M ^ b i n \hat{M}_{b i n} M^bin和受损的输入图片 I m c o n I_{m_{con}} Imcon融合,得到binary-masked image I m b i n I_{m_{bin}} Imbin,作为下一阶段Part 2 的输入。

I m b i n = I m c o n ⨀ ( 1 − M ^ b i n ) + M ^ b i n I_{m_{bin}} = I_{m_{con}}\bigodot(1-\hat{M}_{b i n}) +\hat{M}_{b i n} Imbin=Imcon(1M^bin)+M^bin

  • TMDM(Transformer-based Mask Detection Module) = FAD(Frequency Anomaly Detector) + CViT(Convolutional Vision Transformer)
  • 损失函数:交叉熵损失+dice loss

FAD(Frequency Anomaly Detector)

  • 利用一下方法检测视觉不一致区域信号:

D high  = F ( D C T ( I m con  ) , α ) D_{\text {high }}=\mathcal{F}\left(D C T\left(I_{m_{\text {con }}}\right), \alpha\right) Dhigh =F(DCT(Imcon ),α)

首先利用DCT变换将 I m con  ∈ R H × W × 3 I_{m_{\text {con }}}\in\mathbb{R}^{H×W×3} Imcon RH×W×3从RGB空间转换为频率空间;然后利用手工过滤器过滤低频信息,从而得到视觉不一致区域的高频信息信号。

F \mathcal{F} F: high-pass filter

α \alpha α: manually-chosen threshold,控制要过滤掉的频率分量

D high  ∈ R H × W × 1 D_{\text {high }}\in \mathbb{R}^{H×W×1} Dhigh RH×W×1: 输入受损图像 I m con  I_{m_{\text {con }}} Imcon 的高频率分量

  • 利用逆DCT将滤波后的信号转换回RGB空间,得到frequency-aware representation(目的是为了保持变换的invariance和自然图像局部一致性blablabla…)

F = D C T − 1 ( D h i g h ) , F ∈ R H × W × 1 F=D C T^{-1}\left(D_{h i g h}\right), F\in\mathbb{R}^{H×W×1} F=DCT1(Dhigh),FRH×W×1

CViT(Convolutional Vision Transformer)

  • 将输入图像 I m con  I_{m_{\text {con }}} Imcon reshaped into a squence of flattened patches

{ z i } i = 1 P ,  where  z i ∈ R 1 × Q , P 为 序 列 的 长 度 \left\{z_{i}\right\}_{i=1}^{P}, \text { where } z_{i} \in \mathbb{R}^{1 \times Q},P为序列的长度 {zi}i=1P, where ziR1×QP

在这篇文章中,P=64,Q= ( H / 8 ) × ( W / 8 ) × 64 (H/8)\times(W/8)\times 64 (H/8)×(W/8)×64

reshaped之后就add position embeddings,得到最终的feature vector然后输入进L stacked transformer encoders。

在这里插入图片描述

  • 每个encoder层都是标准结构(多头注意力块+多层感知机)

对于每个head h h h,利用全连接层将feature vector和query,key,value embedding q h , k h , v h q^h,k^h,v^h qh,kh,vh分别映射。然后利用如下公式计算不同heads的注意力图 A h A^h Ah
A i , j h = Softmax ⁡ ( q i h k j h ⊤ d k ) A_{i, j}^{h}=\operatorname{Softmax}\left(\frac{q_{i}^{h} k_{j}^{h \top}}{\sqrt{d_{k}}}\right) Ai,jh=Softmax(dk qihkjh)
i,j: position index

d k d_k dk: key embedding的维度

然后,将得到的不同head的注意图沿第一维度连接起来,以获得最终的self-attention 图 A ∈ R N h × P × P A∈ R^{N_h×P×P} ARNh×P×P N h N_h Nh是head数量。

  • 注意力图中加入频率分量。首先利用几个卷积层将frequency-aware representation F转换为频率模态注意力图 A frep  ∈ R C × P × P A_{\text {frep }} \in \mathbb{R}^{C \times P \times P} Afrep RC×P×P

A dual  = Softmax ⁡ ( conv ⁡ 1 × 1 ( [ A , A f r e q ] ) ) A_{\text {dual }}=\operatorname{Softmax}\left(\operatorname{conv}_{1 \times 1}\left(\left[A, A_{f r e q}\right]\right)\right) Adual =Softmax(conv1×1([A,Afreq]))

首先先将 A , A f r e q A, A_{f r e q} A,Afreq沿着第一维度连接concat。然后利用一个1x1的卷积将通道数转换回 N h N_h Nh。而后计算output=weighted summation of the attention weights and values v h v^h vhof relavant patches,output和input之间有个一残差连接,而后送进MLP处理。

  • 将序列输出reassemble成2D feature map T ∈ R H / 8 × W / 8 × 64 T\in\mathbb{R}^{H/8×W/8×64} TRH/8×W/8×64
  • 在transformer架构顶部加入Patch Similarity Block(PS Block),获得边缘图E:

E i = 1 ∣ Ω ∣ ∑ j ∈ Ω Sim ⁡ ( T i , T j ) E_{i}=\frac{1}{|\Omega|} \sum_{j \in \Omega} \operatorname{Sim}\left(T_{i}, T_{j}\right) Ei=Ω1jΩSim(Ti,Tj)

Ω \Omega Ω: a small neighboring patch in the feature map T around i(文中将该数值设为9)

Sim(,): 使用余弦相似度

  • 然后,将边缘图E和输入特征图T相加得到一个保留边缘的特征图 F e d g e F_{edge} Fedge

在这里插入图片描述

  • 最后利用连续的双线性上采样层和1x1卷积逐步提高 F e d g e F_{edge} Fedge的空间分辨率,并获得最终的掩膜检测结果 M ^ b i n \hat{M}_{b i n} M^bin

Part 2 TDRefine Inpainting Module

  • 模型 G P \mathcal{G}_{P} GP基于encoder-decoder基础架构。

I pred  = G P ( I m b i n , M ^ b i n ; θ p ) I_{\text {pred }}=\mathcal{G}_{P}\left(I_{m_{b i n}}, \hat{M}_{b i n} ; \theta_{p}\right) Ipred =GP(Imbin,M^bin;θp)

其中 θ p \theta_p θp表示网络参数

  • 首先将 I m b i n I_{m_{b i n}} Imbin作为输入,利用Landmark Prediction Module算法(略)获得面部关键点 L l d m k L_{ldmk} Lldmk
  • 然后将 L l d m k L_{ldmk} Lldmk I m b i n I_{m_{b i n}} Imbin M ^ b i n \hat{M}_{b i n} M^bin连接(concat),作为第一个TDRB block的输入。
  • TDRefine模块的自底向上路径包含一个编码器,该编码器会下采样两次,然后是7个扩张卷积的残差块和一个长短期注意力块。
  • 自上而下的路径中有几个Top-Down Refinement fusion blocks(TDRB)

ϕ d i + 1 = TDRB ⁡ ( ϕ d i , ϕ e i + 1 , m i + 1 ) \phi_{d_{i+1}}=\operatorname{TDRB}\left(\phi_{d_{i}}, \phi_{e_{i+1}}, m_{i+1}\right) ϕdi+1=TDRB(ϕdi,ϕei+1,mi+1)

ϕ d i \phi_{d_{i}} ϕdi: 自上而下路径生成的特征图

ϕ e i + 1 \phi_{e_{i+1}} ϕei+1: 编码器层的特征图

m i + 1 m_{i+1} mi+1: 预测的掩膜,表明代填充区域

首先利用de-convolution层上采样 ϕ d i \phi_{d_{i}} ϕdi,使之和 ϕ e i + 1 \phi_{e_{i+1}} ϕei+1大小相同,然后将两者根据掩膜 m i + 1 m_{i+1} mi+1提供的信息融合:
ϕ ~ d i + 1 = conv ⁡ 1 × 1 ( ( deconv ⁡ ( ϕ d i ) ⊙ m i + 1 + ϕ e i + 1 ⊙ ( 1 − m i + 1 ) ) \tilde{\phi}_{d_{i+1}}=\operatorname{conv}_{1 \times 1}\left(\left(\operatorname{deconv} \left(\phi_{d_{i}}\right) \odot m_{i+1}+\phi_{e_{i+1}} \odot\left(1-m_{i+1}\right)\right)\right. ϕ~di+1=conv1×1((deconv(ϕdi)mi+1+ϕei+1(1mi+1))
然后利用区域归一化算法(略)来均衡掩膜内外的特征:
ϕ ˉ d i + 1 = R N ( ϕ ~ d i + 1 , m i + 1 ) \bar{\phi}_{d_{i+1}}=R N\left(\tilde{\phi}_{d_{i+1}}, m_{i+1}\right) ϕˉdi+1=RN(ϕ~di+1,mi+1)
ϕ ˉ d i + 1 \bar{\phi}_{d_{i+1}} ϕˉdi+1经过一次卷积层生成refined feature ϕ d i + 1 {\phi}_{d_{i+1}} ϕdi+1,每一个refinement block的输入 M ^ b i n \hat{M}_{b i n} M^bin大小都不一样。

在这里插入图片描述

  • 引入PatchGAN,图中未展示判别器部分。
  • 损失函数:

L inpaiting  = λ recons  L recons  + λ a d v L a d v G + λ perc  L perc  + λ style  L style  + λ t v L t v \begin{aligned} \mathcal{L}_{\text {inpaiting }}=& \lambda_{\text {recons }} \mathcal{L}_{\text {recons }}+\lambda_{a d v} \mathcal{L}_{a d v_{G}}+\\ & \lambda_{\text {perc }} \mathcal{L}_{\text {perc }}+\lambda_{\text {style }} \mathcal{L}_{\text {style }}+\lambda_{t v} \mathcal{L}_{t v} \end{aligned} Linpaiting =λrecons Lrecons +λadvLadvG+λperc Lperc +λstyle Lstyle +λtvLtv

训练过程

  • 首先,训练掩模检测模块。
  • 然后,以端到端的方式联合训练掩模检测模块和修复模块。

和VCNet对比:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8XhUEUQk-1645010284205)(C:\Users\Vanessa Ni\AppData\Roaming\Typora\typora-user-images\image-20220216184411488.png)]

References

  1. 在线LaTeX公式编辑器-编辑器 (latexlive.com)
  2. 论文解释:Vision Transformers和CNN看到的特征是相同的吗?
举报

相关推荐

0 条评论