有时对于一个网络,我们只需要它的部分网络层,当然可以直接在模型定义里把不想要的层删掉,但是我认为这不是最方便的做法,以resnet
为例,我是直接复制pytorch
对resnet
的官方实现代码。
model = resnet18()
extractor = nn.Sequential(*list(model.children()[:-2]))
classifier = nn.Sequential(*list(model.children())[-2:])
return extractor, classifier
这样就可以获得网络的指定层。