0
点赞
收藏
分享

微信扫一扫

5、模型的剪枝、测试

仲秋花似锦 2022-04-13 阅读 38
深度学习
import time

import torch
from torch import nn
import torch.nn.utils.prune as prune
from torchsummary import summary
import torch.nn.functional as F
# Compute device: the current CUDA device when one is available, otherwise the CPU.
# NOTE(review): not referenced by any visible code in this script.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class LeNet(nn.Module):
    """Small LeNet-style CNN mapping single-channel 30x30 images to 10 logits.

    The layer sizes are fixed for a 30x30 input. Two unpadded 3x3 convolutions
    shrink it to 28x28 and then to 26x26, so ``fc1`` expects
    16 * 26 * 26 = 10816 features. Other spatial sizes fail at ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 6 feature maps, 3x3 kernel, no padding: 30x30 -> 28x28
        self.conv1 = nn.Conv2d(1, 6, 3)  # (6, 28, 28)
        self.re1 = nn.ReLU(inplace=True)
        # 6 -> 16 feature maps, 3x3 kernel, no padding: 28x28 -> 26x26
        self.conv2 = nn.Conv2d(6, 16, 3)  # (16, 26, 26)
        self.re2 = nn.ReLU(inplace=True)
        # Flattened conv2 output: 16 channels * 26 * 26 pixels.
        self.fc1 = nn.Linear(1 * 16 * 26 * 26, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: batched ``(N, 1, 30, 30)`` or unbatched ``(1, 30, 30)`` tensor.

        Returns:
            ``(N, 10)`` logits for batched input, ``(10,)`` for unbatched input.
        """
        x = self.re1(self.conv1(x))
        x = self.re2(self.conv2(x))
        # Flatten only the trailing (C, H, W) dims so a batch dim survives.
        # Flattening from dim 0 merged the batch into the features and made
        # fc1 fail for any batch larger than one.
        x = torch.flatten(x, start_dim=-3)
        # Nonlinearities between the fully connected layers: without them,
        # fc1 -> fc2 -> fc3 collapses into a single affine map.
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
def pruna_model(src="model.pth", dst="pruna30.pth", amount=0.3):
    """Randomly prune Conv2d/Linear weights of a pickled model and save the result.

    Loads the whole module pickled at ``src`` and applies random unstructured
    pruning to the ``weight`` of every ``nn.Conv2d`` and ``nn.Linear``
    submodule. It then makes the pruning permanent, prints every parameter,
    and saves the pruned module to ``dst``.

    Args:
        src: path of a module saved with ``torch.save(model, ...)``.
        dst: output path for the pruned module.
        amount: fraction (float in [0, 1]) or absolute count (int) of weights
            to zero per pruned layer; passed to ``prune.random_unstructured``.

    Returns:
        The pruned module (also written to ``dst``).
    """
    # NOTE: this unpickles arbitrary Python objects; only load trusted files.
    try:
        models = torch.load(src, weights_only=False)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` keyword and always unpickles fully.
        models = torch.load(src)
    for module in models.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Random unstructured pruning of the weight tensor. Other pruning
            # methods could be used instead, or name='bias' to prune biases.
            # ``amount`` is the pruning ratio.
            prune.random_unstructured(module, name="weight", amount=amount)
            # Fold weight_orig * weight_mask back into a plain ``weight``
            # parameter so the saved model carries no pruning reparametrization.
            prune.remove(module, "weight")
    for name, param in models.named_parameters():
        print(name)
        print(param)
    torch.save(models, dst)
    return models
def test_model(path="pruna30.pth"):
    """Smoke-test a pickled model on one random (1, 1, 30, 30) input and print the output.

    Args:
        path: file written by ``torch.save(model, ...)``. The default matches
            the file name ``pruna_model`` saves to by default.

    Returns:
        The model's output tensor for the random input.
    """
    input_data = torch.randn(1, 1, 30, 30)  # one single-channel 30x30 image
    # NOTE: this unpickles arbitrary Python objects; only load trusted files.
    try:
        model = torch.load(path, weights_only=False)
    except TypeError:
        # torch < 1.13 has no ``weights_only`` keyword and always unpickles fully.
        model = torch.load(path)
    model.eval()
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        pred = model(input_data)
    print("result : ", pred)
    return pred
if __name__ == "__main__":
    # Demo pipeline: build an untrained LeNet and pickle it to model.pth,
    # prune that file into pruna30.pth, then run the pruned model once.
    model = LeNet()
    torch.save(model, "model.pth")
    pruna_model()
    test_model()

举报

相关推荐

0 条评论