import time
import torch
from torch import nn
import torch.nn.utils.prune as prune
from torchsummary import summary
import torch.nn.functional as F
# Preferred compute device: the first CUDA GPU when available, otherwise CPU.
# NOTE(review): not referenced by any code visible in this file — confirm it is
# still needed before relying on it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class LeNet(nn.Module):
    """Small LeNet-style CNN for single-channel 30x30 images.

    Two unpadded 3x3 convolutions (1 -> 6 -> 16 channels), each followed by an
    in-place ReLU, then three fully connected layers ending in 10 outputs.
    With a 30x30 input the feature map after ``conv2`` is 16x26x26, which is
    what fixes the ``fc1`` input size.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.re1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.re2 = nn.ReLU(inplace=True)
        # 30x30 -> 28x28 (conv1, 3x3, no padding) -> 26x26 (conv2); 16 channels.
        self.fc1 = nn.Linear(16 * 26 * 26, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network.

        Args:
            x: a ``(N, 1, 30, 30)`` batch, or a single unbatched ``(1, 30, 30)``
                image.

        Returns:
            A ``(N, 10)`` tensor for batched input, ``(10,)`` for unbatched input.
        """
        x = self.re1(self.conv1(x))
        x = self.re2(self.conv2(x))
        # Keep the batch dimension. Flattening from dim 0 merged every sample
        # into one vector, so any batch with N > 1 failed inside fc1.
        x = torch.flatten(x, start_dim=1) if x.dim() == 4 else torch.flatten(x)
        # NOTE(review): there is no nonlinearity between the fc layers, so
        # fc1..fc3 compose to a single affine map. Left as-is to keep outputs
        # of existing checkpoints unchanged.
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)
def pruna_model(src_path="model.pth", dst_path="pruna30.pth", amount=0.3, *, verbose=True):
    """Randomly prune every Conv2d/Linear weight of a saved model and re-save it.

    Loads the whole pickled module from ``src_path``. For each ``nn.Conv2d`` and
    ``nn.Linear`` it masks a random ``amount`` of the ``weight`` tensor; biases
    are untouched. ``prune.remove`` then makes the pruning permanent, so the
    zeros live directly in ``weight`` and no ``weight_orig``/``weight_mask`` is
    left behind. Finally the module is saved to ``dst_path``.

    Args:
        src_path: checkpoint written by ``torch.save(model, path)``.
        dst_path: destination for the pruned module.
        amount: fraction (float in [0, 1]) or absolute count (int) of weights
            to prune per layer, as accepted by ``prune.random_unstructured``.
        verbose: when True, print every parameter name and tensor after pruning.

    Returns:
        The pruned module.
    """
    # SECURITY: full-module checkpoints are unpickled, which can execute
    # arbitrary code -- only load trusted files.
    try:
        # torch >= 2.6 defaults weights_only=True, which rejects whole modules.
        model = torch.load(src_path, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only keyword.
        model = torch.load(src_path)

    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.random_unstructured(module, name="weight", amount=amount)
            prune.remove(module, "weight")

    if verbose:
        for name, param in model.named_parameters():
            print(name)
            print(param)

    torch.save(model, dst_path)
    return model
def test_model(path="pruna30.pth"):
    """Smoke-check a saved model by running one random input through it.

    Loads the whole pickled module from ``path`` and feeds it a random
    ``(1, 1, 30, 30)`` tensor in eval mode with autograd disabled. The output is
    printed and returned.

    Args:
        path: checkpoint written by ``torch.save(model, path)``.

    Returns:
        The model output tensor.
    """
    input_data = torch.randn(1, 1, 30, 30)
    # SECURITY: unpickles a full module -- only load trusted files.
    try:
        # torch >= 2.6 defaults weights_only=True, which rejects whole modules.
        model = torch.load(path, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only keyword.
        model = torch.load(path)
    model.eval()
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        pred = model(input_data)
    print("result : ", pred)
    return pred
if __name__ == "__main__":
    # End-to-end demo: checkpoint a fresh (untrained) LeNet, prune it, then run it.
    torch.save(LeNet(), "model.pth")
    pruna_model()
    test_model()