0
点赞
收藏
分享

微信扫一扫

每天一个数据分析题(一百九十六)

   在迁移学习中,要使用预训练权重。在训练好的网络权重中,使用的是1000类的分类任务。而我们自己的分类任务中没有1000类,需要修改网络。如果将MobileNetV2(num_classes=5),我们的网络在最后输出是5类,与1训练好的1000类不符。加载时就会报错。这时候就要将训练参数中的最后输出1000类的那层删除掉。

    # create model
    net = MobileNetV2(num_classes=5)

    # load pretrain weights
    # download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    model_weight_path = "./mobilenet_v2.pth"
    assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)
    pre_weights = torch.load(model_weight_path, map_location='cpu')

    # delete classifier weights
    pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
    missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)

    # freeze features weights
    for param in net.features.parameters():
        param.requires_grad = False

然后将参数冻结,只训练最后一层参数。

参考视频:如何在Pytorch中载入部分权重_哔哩哔哩_bilibili

举报

相关推荐

0 条评论