0
点赞
收藏
分享

微信扫一扫

特征提取后的维度变化

GG_lyf 2022-03-22 阅读 86
python

主要用于在使用一个现有的模型对数据进行特征提取时,进行维度降维,主要使用到线性变化,例如使用resnet50进行特征提取时,可能得到的是2048维特征,但是此时如果进行线性变化,就可得到512维或是128维等特征

class LinearNorm(nn.Module):
    def __init__(self, cfg):
        super(LinearNorm, self).__init__()
        self.fc = nn.Linear(cfg['IN_CHANNELS'], cfg['DIM'])
        self.fc.apply(weights_init_kaiming)

    def forward(self, x):
        x = self.fc(x)
        x = nn.functional.normalize(x, p=2, dim=1)
        return x

cfg['IN_CHANNELS'], cfg['DIM']分别表示输入维度和输出维度

使用方法

backbone为要使用的特征提取模型


head = LinearNorm(head_cfg)
self.model = Sequential(OrderedDict([("backbone", backbone), ("head", head)]))
举报

相关推荐

0 条评论