0
点赞
收藏
分享

微信扫一扫

pytorch获取网络指定层

有时对于一个网络,我们只需要它的部分网络层,当然可以直接在模型定义里把不想要的层删掉,但是我认为这不是最方便的做法,以resnet为例,我是直接复制pytorchresnet的官方实现代码。

model = resnet18()

extractor = nn.Sequential(*list(model.children()[:-2]))
classifier = nn.Sequential(*list(model.children())[-2:])
return extractor, classifier

这样就可以获得网络的指定层。

举报

相关推荐

0 条评论