pytorch获取网络指定层

阅读 39

2022-01-31

有时对于一个网络,我们只需要它的部分网络层,当然可以直接在模型定义里把不想要的层删掉,但是我认为这不是最方便的做法,以resnet为例,我是直接复制pytorchresnet的官方实现代码。

model = resnet18()

extractor = nn.Sequential(*list(model.children()[:-2]))
classifier = nn.Sequential(*list(model.children())[-2:])
return extractor, classifier

这样就可以获得网络的指定层。

相关推荐

精彩评论(0)

0 0 举报