主要用于在使用一个现有的模型对数据进行特征提取时,进行维度降维,主要使用到线性变化,例如使用resnet50
进行特征提取时,可能得到的是2048维特征,但是此时如果进行线性变化,就可得到512维或是128维等特征
class LinearNorm(nn.Module):
def __init__(self, cfg):
super(LinearNorm, self).__init__()
self.fc = nn.Linear(cfg['IN_CHANNELS'], cfg['DIM'])
self.fc.apply(weights_init_kaiming)
def forward(self, x):
x = self.fc(x)
x = nn.functional.normalize(x, p=2, dim=1)
return x
cfg['IN_CHANNELS'], cfg['DIM']
分别表示输入维度和输出维度
使用方法
backbone
为要使用的特征提取模型
head = LinearNorm(head_cfg)
self.model = Sequential(OrderedDict([("backbone", backbone), ("head", head)]))