torch.nn.utils.prune can prune a PyTorch model. The official tutorial is here:
https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
Let's go straight to the code.
First, build the model network:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(in_features=16 * 16 * 24, out_features=num_classes)

    def forward(self, x):
        # F.relu is used instead of constructing a fresh nn.ReLU() on every call;
        # the behavior is identical.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)  # halves the spatial size
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(-1, 16 * 16 * 24)
        return self.fc(x)

model = SimpleNet().to(device=device)
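As a quick sanity check (my addition, not in the original), the shapes line up for a 32×32 RGB input: the single 2×2 max-pool halves 32×32 to 16×16, which is why fc expects 16 * 16 * 24 input features. A minimal stand-in mirroring the conv stack:

```python
import torch
import torch.nn as nn

# Stand-in mirroring SimpleNet's conv layers (an illustrative sketch,
# not the original class): one 2x2 max-pool halves 32x32 to 16x16.
body = nn.Sequential(
    nn.Conv2d(3, 12, 3, padding=1), nn.ReLU(),
    nn.Conv2d(12, 12, 3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(12, 24, 3, padding=1), nn.ReLU(),
    nn.Conv2d(24, 24, 3, padding=1), nn.ReLU(),
)
x = torch.randn(1, 3, 32, 32)  # assumed CIFAR-10-sized input
feat = body(x)
print(feat.shape)              # torch.Size([1, 24, 16, 16])
flat = feat.view(1, -1)
print(flat.shape[1])           # 6144 == 16 * 16 * 24
```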
Take a look at the model summary (the input size, 3×32×32, follows from the fc layer expecting 16 * 16 * 24 features after one 2×2 pooling):

summary(model, (3, 32, 32))
This article shows how to apply L1 structured pruning to a PyTorch model with torch.nn.utils.prune, how the model looks after pruning, and how pruning can affect computation in practice. Although pruning reduces the parameter count, it does not noticeably improve runtime efficiency; the lightweighting is largely theoretical.
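The L1 structured pruning step itself can be sketched as follows (a minimal illustration; pruning 50% of a conv layer's output channels is my choice for the example, not taken from the original):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(3, 12, kernel_size=3, padding=1)  # stands in for model.conv1

# L1 structured pruning: zero out the 50% of output channels (dim=0)
# with the smallest L1 norm. amount=0.5 is an illustrative value.
prune.ln_structured(conv, name="weight", amount=0.5, n=1, dim=0)

# Pruning is applied through a mask: conv now has weight_orig and
# weight_mask, and conv.weight is their product.
zeroed = (conv.weight.abs().sum(dim=(1, 2, 3)) == 0).sum().item()
print(zeroed)  # 6 of the 12 output channels are zeroed

# Make the pruning permanent (drops the mask and reparametrization).
prune.remove(conv, "weight")
```

Note that the pruned weights are merely masked to zero; the tensor shapes stay the same, which is why dense matmul time barely changes and the speedup remains theoretical.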