在pytorch中,如何对标准的预训练模型进行修改以适应三通道以上的输入

阅读 68

2022-06-27

net = resnet50(pretrained=pretrained)
with torch.no_grad():
pretrained_conv1 = net.conv1.weight.clone()
# Assign new conv layer with 4 input channels
net.conv1 = torch.nn.Conv2d(4, 64, 7, 2, 3, bias=False)
# Use same initialization as vanilla ResNet (Don't know if good idea)
torch.nn.init.kaiming_normal_(
net.conv1.weight, mode='fan_out', nonlinearity='relu')
# Re-assign pretraiend weights to first 3 channels
# (assuming alpha channel is last in your input data)
net.conv1.weight[:, :3] = pretrained_conv1

此代码的作用是修改标准预训练模型来适应四通道的输入,前三个通道保持原来的参数,最后一个通道kaiming初始化


精彩评论(0)

0 0 举报