❤️【深度学习入门项目】❤️ 之 【超分重建】
❤️ 原创:墨理学AI
本博文专注于记录 BSRGAN 训练过程,是如下深度学习入门博文的续写
BSRGAN -- 环境搭建 - 测试部分 - 移步该博文 【如何让女神更清晰,我的白月光】
📔 基础信息
🟧 代码仓库
- 训练代码仓库:https://github.com/cszn/KAIR
下载仓库代码
git clone https://github.com/cszn/KAIR.git
🟨 简单说明
该论文代码本身 致力于 设计一个实用的退化模型【Designing a Practical Degradation Model】 因此,如果要复现该论文,还是需要具体看论文,来下载对应的 HR 数据 和 测试数据 具体研究和复现,也不是本专栏博文重点,因此,这里只是简单记录该代码的训练步骤
📕 训练数据集准备
github.com/cszn/KAIR 主页有数据下载链接
这里使用了 DIV2K_train_HR/ 800 张 *.png 图片进行训练
📗 训练参数配置
关于训练,ReadMe 其实交代的比较清楚了
🔴 在进行重现实验时,主要会调整的基本参数如下
- BSRNet 对应配置为 :options/train_bsrgan_x4_psnr.json
- BSRGAN 对应配置为 :options/train_bsrgan_x4_gan.json
- 两个配置的待修改参数相似
🔵 BSRNet 和 BSRGAN 配置上的区别
🟣 Train BSRNet
batch_size 为 4时,单卡 GPU占用 6933MiB ,90分钟 训练 iter: 10,000次,【数据集为 DIV2k 800 训练数据】
cd KAIR/
#环境搭建,参考上篇博文即可
conda activate torch18
python main_train_psnr.py --opt options/train_bsrgan_x4_psnr.json
可以在 log 日志 【KAIR/superresolution/bsrgan_x4_psnr/train.log】中看到当训练 iter: 10,000 时,打印了 set5 测试的PSNR 效果如下【因为我在配置中设置 "checkpoint_test": 10000 】
这种设置对于【魔改】模型训练,及时发现训练效果是否提升,会有一定帮助;
21-09-15 04:03:32.434 : <epoch: 49, iter: 10,000, lr:1.000e-04> G_loss: 6.757e-02
21-09-15 04:03:32.435 : Saving the model.
21-09-15 04:03:34.821 : ---1--> baby.bmp | 26.35dB
21-09-15 04:03:34.940 : ---2--> bird.bmp | 32.46dB
21-09-15 04:03:35.350 : ---3--> butterfly.bmp | 26.62dB
21-09-15 04:03:35.637 : ---4--> head.bmp | 41.31dB
21-09-15 04:03:35.747 : ---5--> woman.bmp | 26.89dB
21-09-15 04:03:35.806 : <epoch: 49, iter: 10,000, Average PSNR : 30.72dB
21-09-15 04:05:06.854 : <epoch: 50, iter: 10,200, lr:1.000e-04> G_loss: 7.127e-02
iter: 10,000 训练生成如下
🟡 Train BSRGAN
batch_size 为 4 时,单卡 GPU占用 9667MiBiter: 10,000次耗时 2小时20分钟【数据集为 DIV2k 800 训练数据】
python main_train_gan.py --opt options/train_bsrgan_x4_gan.json
训练生成效果如下
🎉 附代码+训练数据
该博客记录对应代码 + DIV2K 训练数据集
各位小伙伴按照博文教程或者官方文档,自行构建项目,更能提高和锻炼自己
祝学习顺利,感谢订阅
链接:https://pan.baidu.com/s/1RKjHDxtPuPdFmzEH4oqcOg
提取码:1103
代码目录结构如下
tree -L 3
.
├── data
│ ├── dataset_blindsr.py
│ ├── dataset_dncnn.py
│ ├── dataset_dnpatch.py
│ ├── dataset_dpsr.py
│ ├── dataset_fdncnn.py
│ ├── dataset_ffdnet.py
│ ├── dataset_l.py
│ ├── dataset_plainpatch.py
│ ├── dataset_plain.py
│ ├── dataset_srmd.py
│ ├── dataset_sr.py
│ ├── dataset_usrnet.py
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── dataset_blindsr.cpython-37.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ └── select_dataset.cpython-37.pyc
│ └── select_dataset.py
├── figs
│ ├── face_01_comparison.png
│ ├── face_04_comparison.png
│ ├── face_08_comparison.png
│ ├── face_10_comparison.png
│ ├── face_12_comparison.png
│ ├── face_13_comparison.png
│ ├── imdn_block.png
│ └── imdn.png
├── kernels
│ ├── kernels_12.mat
│ ├── kernels_bicubicx234.mat
│ ├── k_large_1.png
│ ├── k_large_2.png
│ ├── Levin09.mat
│ └── srmd_pca_matlab.mat
├── LICENSE
├── main_challenge_sr.py
├── main_test_dncnn3_deblocking.py
├── main_test_dncnn.py
├── main_test_dpsr.py
├── main_test_face_enhancement.py
├── main_test_fdncnn.py
├── main_test_ffdnet.py
├── main_test_imdn.py
├── main_test_ircnn_denoiser.py
├── main_test_msrresnet.py
├── main_test_rrdb.py
├── main_test_srmd.py
├── main_test_usrnet.py
├── main_train_dncnn.py
├── main_train_drunet.py
├── main_train_gan.py
├── main_train_psnr.py
├── main_train_usrnet.py
├── models
│ ├── basicblock.py
│ ├── einstein.png
│ ├── loss.py
│ ├── loss_ssim.py
│ ├── model_base.py
│ ├── model_gan.py
│ ├── model_plain2.py
│ ├── model_plain4.py
│ ├── model_plain.py
│ ├── network_discriminator.py
│ ├── network_dncnn.py
│ ├── network_dpsr.py
│ ├── network_faceenhancer.py
│ ├── network_feature.py
│ ├── network_ffdnet.py
│ ├── network_imdn.py
│ ├── network_msrresnet.py
│ ├── network_rrdbnet.py
│ ├── network_rrdb.py
│ ├── network_srmd.py
│ ├── network_swinir.py
│ ├── network_unet.py
│ ├── network_usrnet.py
│ ├── network_usrnet_v1.py
│ ├── op
│ │ ├── fused_act.py
│ │ ├── fused_bias_act.cpp
│ │ ├── fused_bias_act_kernel.cu
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ ├── upfirdn2d.cpp
│ │ ├── upfirdn2d_kernel.cu
│ │ └── upfirdn2d.py
│ ├── __pycache__
│ │ ├── basicblock.cpython-37.pyc
│ │ ├── loss.cpython-37.pyc
│ │ ├── loss_ssim.cpython-37.pyc
│ │ ├── model_base.cpython-37.pyc
│ │ ├── model_gan.cpython-37.pyc
│ │ ├── model_plain.cpython-37.pyc
│ │ ├── network_discriminator.cpython-37.pyc
│ │ ├── network_rrdbnet.cpython-37.pyc
│ │ ├── select_model.cpython-37.pyc
│ │ └── select_network.cpython-37.pyc
│ ├── select_model.py
│ └── select_network.py
├── model_zoo
│ ├── dncnn_25.pth
│ └── README.md
├── options
│ ├── train_bsrgan_x4_gan.json
│ ├── train_bsrgan_x4_psnr.json
│ ├── train_dncnn.json
│ ├── train_dpsr.json
│ ├── train_drunet.json
│ ├── train_fdncnn.json
│ ├── train_ffdnet.json
│ ├── train_imdn.json
│ ├── train_msrresnet_gan.json
│ ├── train_msrresnet_psnr.json
│ ├── train_rrdb_psnr.json
│ ├── train_srmd.json
│ ├── train_swinir_denoising_c_50.json
│ ├── train_swinir_denoising_g_50.json
│ └── train_usrnet.json
├── README.md
├── requirement.txt
├── results
│ └── README.md
├── retinaface
│ ├── data_faces
│ │ ├── config.py
│ │ ├── data_augment.py
│ │ ├── FDDB
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ └── wider_face.py
│ ├── facemodels
│ │ ├── __init__.py
│ │ ├── net.py
│ │ └── retinaface.py
│ ├── layers
│ │ ├── functions
│ │ ├── __init__.py
│ │ ├── modules
│ │ └── __pycache__
│ ├── README.md
│ ├── retinaface_detection.py
│ └── utils_faces
│ ├── box_utils.py
│ ├── __init__.py
│ ├── nms
│ └── timer.py
├── superresolution
│ ├── bsrgan_x4_gan
│ │ ├── images
│ │ ├── models
│ │ ├── options
│ │ └── train.log
│ └── bsrgan_x4_psnr
│ ├── images
│ ├── models
│ ├── options
│ └── train.log
├── testsets
│ ├── README.md
│ ├── real_faces
│ │ ├── face_01.png
│ │ ├── face_02.png
│ │ ├── face_03.png
│ │ ├── face_04.png
│ │ ├── face_05.png
│ │ ├── face_06.png
│ │ ├── face_07.png
│ │ ├── face_08.png
│ │ ├── face_09.png
│ │ ├── face_10.png
│ │ ├── face_11.png
│ │ ├── face_12.png
│ │ ├── face_13.png
│ │ └── face_14.png
│ ├── set12
│ │ ├── 01.png
│ │ ├── 02.png
│ │ ├── 03.png
│ │ ├── 04.png
│ │ ├── 05.png
│ │ ├── 06.png
│ │ ├── 07.png
│ │ ├── 08.png
│ │ ├── 09.png
│ │ ├── 10.png
│ │ ├── 11.png
│ │ └── 12.png
│ └── set5
│ ├── baby.bmp
│ ├── bird.bmp
│ ├── butterfly.bmp
│ ├── head.bmp
│ └── woman.bmp
├── trainsets
│ ├── README.md
│ └── trainH
│ ├── DIV2K_train_HR
│ └── README.md
└── utils
├── __pycache__
│ ├── utils_blindsr.cpython-37.pyc
│ ├── utils_bnorm.cpython-37.pyc
│ ├── utils_dist.cpython-37.pyc
│ ├── utils_image.cpython-37.pyc
│ ├── utils_logger.cpython-37.pyc
│ ├── utils_model.cpython-37.pyc
│ ├── utils_option.cpython-37.pyc
│ └── utils_regularizers.cpython-37.pyc
├── test.bmp
├── test.png
├── utils_alignfaces.py
├── utils_blindsr.py
├── utils_bnorm.py
├── utils_deblur.py
├── utils_dist.py
├── utils_googledownload.py
├── utils_image.py
├── utils_logger.py
├── utils_matconvnet.py
├── utils_mat.py
├── utils_model.py
├── utils_modelsummary.py
├── utils_option.py
├── utils_params.py
├── utils_receptivefield.py
├── utils_regularizers.py
└── utils_sisr.py
40 directories, 189 files
📘 结语
只要环境搭建顺利,该代码正确训练和测试,基本没有问题哈: