网站建设费用如何做账,WordPress菜单过滤器,八角网站建设,科技粉末文章目录 一、 prune模块简介1.1 常用方法1.2 剪枝效果1.3 二、三、四章剪枝测试总结 二、局部剪枝#xff08;Local Pruning#xff09;2.1 结构化剪枝2.1.1 对weight进行随机结构化剪枝#xff08;random_structured#xff09;2.1.2 对weight进行迭代剪枝#xff08;范… 文章目录 一、 prune模块简介1.1 常用方法1.2 剪枝效果1.3 二、三、四章剪枝测试总结 二、局部剪枝Local Pruning2.1 结构化剪枝2.1.1 对weight进行随机结构化剪枝random_structured2.1.2 对weight进行迭代剪枝范数结构化剪枝ln_structured 2.2 非结构化剪枝2.2.1 对bias进行随机非结构化剪枝2.2.2 对多层网络进行范数非结构化剪枝l1_unstructured 2.3 永久化剪枝remove 三、全局剪枝(GLobal pruning)四、自定义剪枝(Custom pruning) 《datawhale2411组队学习之模型压缩技术1模型剪枝上》介绍模型压缩的几种技术模型剪枝基本概念、分类方式、剪枝标准、剪枝频次、剪枝后微调等内容《datawhale11月组队学习 模型压缩技术2PyTorch模型剪枝教程》介绍PyTorch的prune模块具体用法《datawhale11月组队学习 模型压缩技术324结构稀疏化BERT模型》介绍基于模式的剪枝——24结构稀疏化及其在BERT模型上的测试效果 项目地址awesome-compression、在线阅读 一、 prune模块简介 PyTorch教程《Pruning Tutorial》、torch.nn.utils.prune文档 1.1 常用方法
Pytorch在1.4.0版本开始加入了剪枝操作在torch.nn.utils.prune模块中主要有以下剪枝方法
剪枝类型子类型剪枝方法局部剪枝结构化剪枝随机结构化剪枝 (random_structured)范数结构化剪枝 (ln_structured)非结构化剪枝随机非结构化剪枝 (random_unstructured)范数非结构化剪枝 (ln_unstructured)全局剪枝非结构化剪枝全局非结构化剪枝 (global_unstructured)自定义剪枝自定义剪枝 (Custom Pruning)
除此之外模块中还有一些其它方法
方法描述prune.remove(module, name)剪枝永久化prune.apply使用指定的剪枝方法对模块进行剪枝。prune.is_pruned(module)检查给定模块的某个参数是否已被剪枝。prune.custom_from_mask(module, name, mask)基于自定义的掩码进行剪枝用于定义更加细粒度的剪枝策略。
1.2 剪枝效果 参数变化 剪枝前weight 是模型的一个参数意味着它是模型训练时优化的对象可以通过梯度更新通过 optimizer.step() 来更新它的值。剪枝过程中原始权重被保存到新的变量 weight_orig中便于后续访问原始权重。剪枝后weight是剪枝后的权重值通过原始权重和剪枝掩码计算得出但此时不再是参数而是模型的属性一个普通的变量。 掩码存储生成一个名为 weight_mask的剪枝掩码会被保存为模块的一个缓冲区buffer。 前向传递PyTorch 使用 forward_pre_hooks 来确保每次前向传递时都会应用剪枝处理。每个被剪枝的参数都会在模块中添加一个钩子来实现这一操作。
1.3 二、三、四章剪枝测试总结
对weight进行剪枝效果见1.2 章节。对weight进行迭代剪枝相当于把多个剪枝核mask序列化成一个剪枝核 最终只有一个weight_orig和weight_maskhook也被更新。对weight剪枝后再对bias进行剪枝weight_orig和weight_mask不变新增bias_orig和bias_mask新增bias hook。可以对多个模块同时进行剪枝最后使用remove进行剪枝永久化 使用remove函数后 weight_orig 和 bias_orig 被移除剪枝后的weight 和 bias 成为标准的模型参数。经过 remove 操作后剪枝永久化生效。此时剪枝掩码weight_mask 和 hook不再需要named_buffers和_forward_pre_hooks 都被清空。局部剪枝需要根据自己的经验来决定对某一层网络进行剪枝需要对模型有深入了解所以全局剪枝跨不同参数更通用即从整体网络的角度进行剪枝。采用全局剪枝时不同的层被剪掉的百分比可能不同。
parameters_to_prune ((model.conv1, weight),(model.conv2, weight),(model.fc1, weight),(model.fc2, weight))# 应用20%全局剪枝
prune.global_unstructured(parameters_to_prune, pruning_methodprune.L1Unstructured, amount0.2)最终各层剪枝比例为随机的
Sparsity in conv1.weight: 5.33%
Sparsity in conv2.weight: 17.25%
Sparsity in fc1.weight: 22.03%
Sparsity in fc2.weight: 14.67%
Global sparsity: 20.00%自定义剪枝需要通过继承class BasePruningMethod()来定义,其内部有若干方法: call, apply_mask, apply, prune, remove。其中必须实现__init__和compute_mask两个函数才能完成自定义的剪枝规则设定。此外您必须指定要实现的修剪类型 global, structured, and unstructured。
二、局部剪枝Local Pruning 局部剪枝指的是对网络的单个层或局部范围内进行剪枝。其中非结构化剪枝会随机地将一些权重参数变为0结构化剪枝则将某个维度某些通道的权重变成0。 总结一下2.1和2.2的效果
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torchsummary import summary# 1.定义一个经典的LeNet网络
class LeNet(nn.Module):def __init__(self, num_classes10):super(LeNet, self).__init__()self.conv1 nn.Conv2d(in_channels1, out_channels6, kernel_size5)self.conv2 nn.Conv2d(in_channels6, out_channels16, kernel_size5)self.maxpool nn.MaxPool2d(kernel_size2, stride2)self.fc1 nn.Linear(in_features16 * 4 * 4, out_features120)self.fc2 nn.Linear(in_features120, out_features84)self.fc3 nn.Linear(in_features84, out_featuresnum_classes)def forward(self, x):x self.maxpool(F.relu(self.conv1(x)))x self.maxpool(F.relu(self.conv2(x)))x x.view(x.size()[0], -1)x F.relu(self.fc1(x))x F.relu(self.fc2(x))x self.fc3(x)return x
device torch.device(cuda if torch.cuda.is_available() else cpu)
model LeNet().to(devicedevice)# 2.打印模型结构
summary(model, input_size(1, 28, 28))----------------------------------------------------------------Layer (type) Output Shape Param #
Conv2d-1 [-1, 6, 24, 24] 156MaxPool2d-2 [-1, 6, 12, 12] 0Conv2d-3 [-1, 16, 8, 8] 2,416MaxPool2d-4 [-1, 16, 4, 4] 0Linear-5 [-1, 120] 30,840Linear-6 [-1, 84] 10,164Linear-7 [-1, 10] 850Total params: 44,426
Trainable params: 44,426
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.04
Params size (MB): 0.17
Estimated Total Size (MB): 0.22
----------------------------------------------------------------# 3.打印模型的状态字典状态字典里包含了所有的参数
print(model.state_dict().keys())odict_keys([conv1.weight, conv1.bias, conv2.weight, conv2.bias, fc1.weight, fc1.bias, fc2.weight, fc2.bias, fc3.weight, fc3.bias])# 4.打印第一个卷积层的参数
module model.conv1
print(list(module.named_parameters()))[(weight, Parameter containing:
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],[ 0.0404, -0.0974, 0.1175, 0.1763, -0.1467],[ 0.1738, 0.0374, 0.1478, 0.0271, 0.0964],[-0.0282, 0.1542, 0.0296, -0.0934, 0.0510],[-0.0921, -0.0235, -0.0812, 0.1327, -0.1579]]],......[[[-0.1167, -0.0685, -0.1579, 0.1677, -0.0397],[ 0.1721, 0.0623, -0.1694, 0.1384, -0.0550],[-0.0767, -0.1660, -0.1988, 0.0572, -0.0437],[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],[ 0.0418, 0.1033, 0.1615, 0.1822, -0.1586]]]], devicecuda:0,requires_gradTrue)), (bias, Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], devicecuda:0,requires_gradTrue))]# 5.打印module中属性张量named_buffers此时为空列表
print(list(module.named_buffers()))[]2.1 结构化剪枝
2.1.1 对weight进行随机结构化剪枝random_structured 对LeNet的conv1层的weight参数进行随机结构化剪枝其中 amount是一个介于0.0-1.0的float数值,代表比例, 或者一个正整数代表剪裁掉多少个参数.
prune.random_structured(module, nameweight, amount2, dim0)# 1.再次打印模型的状态字典发现conv1层多了weight_orig和weight_mask
print(model.state_dict().keys())odict_keys([conv1.bias, conv1.weight_orig, conv1.weight_mask, conv2.weight, conv2.bias, fc1.weight, fc1.bias, fc2.weight, fc2.bias, fc3.weight, fc3.bias])Conv2d(1, 6, kernel_size(5, 5), stride(1, 1))# 2. 剪枝后原始的weight变成了weight_orig并存放在named_parameters中
print(list(module.named_parameters()))[(bias, Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], devicecuda:0,requires_gradTrue)), (weight_orig, Parameter containing:
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],[ 0.0404, -0.0974, 0.1175, 0.1763, -0.1467],[ 0.1738, 0.0374, 0.1478, 0.0271, 0.0964],[-0.0282, 0.1542, 0.0296, -0.0934, 0.0510],[-0.0921, -0.0235, -0.0812, 0.1327, -0.1579]]],......[[[-0.1167, -0.0685, -0.1579, 0.1677, -0.0397],[ 0.1721, 0.0623, -0.1694, 0.1384, -0.0550],[-0.0767, -0.1660, -0.1988, 0.0572, -0.0437],[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],[ 0.0418, 0.1033, 0.1615, 0.1822, -0.1586]]]], devicecuda:0,requires_gradTrue))]# 3. 剪枝掩码矩阵weight_mask存放在模块的buffer中
print(list(module.named_buffers()))[(weight_mask, tensor([[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]],[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]]]))]# 4. 剪枝操作后的weight已经不再是module的参数, 而只是module的一个属性.
print(module.weight)tensor([[[[ 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],[ 0.0000, -0.0000, 0.0000, -0.0000, 0.0000],[ 0.0000, 0.0000, -0.0000, -0.0000, -0.0000],[-0.0000, 0.0000, -0.0000, -0.0000, 0.0000],[ 0.0000, 0.0000, -0.0000, -0.0000, -0.0000]]],[[[-0.0540, -0.1928, -0.0355, -0.0075, -0.1481],[ 0.0135, 0.0192, 0.0082, -0.0120, -0.0164],[-0.0435, -0.1488, 0.1092, -0.0041, 0.1960],[-0.1045, -0.0136, 0.0398, -0.1286, 0.0617],[-0.0091, 0.0466, 0.1827, 0.1655, 0.0727]]],[[[ 0.1216, -0.0833, -0.1491, -0.1143, 0.0113],[ 0.0452, 0.1662, -0.0425, -0.0904, -0.1235],[ 0.0565, 0.0933, -0.0721, 0.0909, 0.1837],[-0.1739, 0.0263, 0.1339, 0.0648, -0.0382],[-0.1667, 0.1478, 0.0448, -0.0892, 0.0815]]],[[[ 0.0000, 0.0000, 0.0000, -0.0000, 0.0000],[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000],[-0.0000, 0.0000, -0.0000, -0.0000, 0.0000],[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000],[ 0.0000, -0.0000, 0.0000, -0.0000, -0.0000]]],[[[ 0.1278, 0.1037, -0.0323, -0.1504, 0.1080],[ 0.0266, -0.0996, 0.1499, -0.0845, 0.0609],[-0.0662, -0.1405, -0.0586, -0.0615, -0.0462],[-0.1118, -0.0961, -0.1325, -0.0417, -0.0741],[ 0.1842, -0.1040, -0.1786, -0.0593, 0.0186]]],[[[-0.0889, -0.0737, -0.1655, -0.1708, -0.0988],[-0.1787, 0.1127, 0.0706, -0.0352, 0.1238],[-0.0985, -0.1929, -0.0062, 0.0488, -0.1152],[-0.1659, -0.0448, 0.0821, -0.0956, -0.0262],[ 0.1928, 0.1767, -0.1792, -0.1364, 0.0507]]]],grad_fnMulBackward0)对于每一次剪枝操作PyTorch 会为剪枝的参数如 weight添加一个 forward_pre_hook。这个钩子会在每次进行前向传递计算之前自动应用剪枝掩码即将某些权重置为零这保证了剪枝后的权重在模型计算时被正确地使用。
# 5.打印_forward_pre_hooks
print(module._forward_pre_hooks)OrderedDict([(0, torch.nn.utils.prune.RandomStructured object at 0x7f04012f8ca0)])简单总结就是
weight 不再是参数它变成了一个属性表示剪枝后的权重。weight_orig 保存原始未剪枝的权重。weight_mask 是一个掩码表示哪些权重被剪去了即哪些位置变为零。钩子会保证每次前向传递时weight 会根据 weight_mask 来计算出剪枝后的版本。
2.1.2 对weight进行迭代剪枝范数结构化剪枝ln_structured 一个模型的参数可以执行多次剪枝操作这种操作被称为迭代剪枝Iterative Pruning。上述步骤已经对conv1进行了随机结构化剪枝接下来对其再进行范数结构化剪枝看看会发生什么
# n代表范数这里n2表示l2范数
prune.ln_structured(module, nameweight, amount0.5, n2, dim0)# 再次打印模型参数
print( model state_dict keys:)
print(model.state_dict().keys())
print(**50)print( module named_parameters:)
print(list(module.named_parameters()))
print(**50)print( module named_buffers:)
print(list(module.named_buffers()))
print(**50)print( module weight:)
print(module.weight)
print(**50)print( module _forward_pre_hooks:)
print(module._forward_pre_hooks)model state_dict keys:
odict_keys([conv1.bias, conv1.weight_orig, conv1.weight_mask, conv2.weight, conv2.bias, fc1.weight, fc1.bias, fc2.weight, fc2.bias, fc3.weight, fc3.bias])
**************************************************
module named_parameters: # 原始参数weight_orig不变
...
...
module named_buffers:
[(weight_mask, tensor([[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]],[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.],[0., 0., 0., 0., 0.]]],[[[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]]]]))]
**************************************************module weight:......
module _forward_pre_hooks:
OrderedDict([(1, torch.nn.utils.prune.PruningContainer object at 0x7f04c86756d0)])可见迭代剪枝相当于把多个剪枝核序列化成一个剪枝核 新的 mask 矩阵与旧的 mask 矩阵的结合由PruningContainer的compute_mask方法处理最后只有一个weight_orig和weight_mask。 module._forward_pre_hooks是一个用于在模型的前向传播之前执行自定义操作的机制这里记录了执行过的剪枝方法
# 打印剪枝历史
for hook in module._forward_pre_hooks.values():if hook._tensor_name weight: breakprint(list(hook))[torch.nn.utils.prune.RandomStructured object at 0x7f04012f8ca0, torch.nn.utils.prune.LnStructured object at 0x7f04c8675b80]2.2 非结构化剪枝
2.2.1 对bias进行随机非结构化剪枝
此时我们也可以继续对偏置bias进行剪枝看看module的参数、缓冲区、钩子和属性是如何变化的。
prune.random_unstructured(module, namebias, amount1)
# 再次打印模型参数
print( model state_dict keys:)
print(model.state_dict().keys())
print(**50)print( module named_parameters:)
print(list(module.named_parameters()))
print(**50)print( module named_buffers:)
print(list(module.named_buffers()))
print(**50)print( module bias:)
print(module.bias)
print(**50)print( module _forward_pre_hooks:)
print(module._forward_pre_hooks)model state_dict keys:
odict_keys([conv1.weight_orig, conv1.bias_orig, conv1.weight_mask, conv1.bias_mask, conv2.weight, conv2.bias, fc1.weight, fc1.bias, fc2.weight, fc2.bias, fc3.weight, fc3.bias])
**************************************************
# weight_orig不变添加了bias_origmodule named_parameters:
[(weight_orig, Parameter containing:...
, requires_gradTrue)), (bias_orig, Parameter containing:
tensor([-0.0893, -0.1464, -0.1101, -0.0076, 0.1493, -0.0418],requires_gradTrue))]
**************************************************
# weight_mask不变添加了bias_maskmodule named_buffers:
[(weight_mask,
...(bias_mask, tensor([1., 1., 0., 1., 1., 1.]))]
**************************************************module bias:
tensor([-0.0893, -0.1464, -0.0000, -0.0076, 0.1493, -0.0418],grad_fnMulBackward0)
**************************************************module _forward_pre_hooks:
OrderedDict([(1, torch.nn.utils.prune.PruningContainer object at 0x7f04c86756d0), (2, torch.nn.utils.prune.RandomUnstructured object at 0x7f04013a7d30)])对bias进行剪枝后会发现state_dict和named_parameters中不仅仅有了weight_orig也有了bias_orig。在named_buffers中, 也同时出现了weight_mask和bias_mask。最后因为我们在两种参数上进行剪枝因此会生成两个钩子。
2.2.2 对多层网络进行范数非结构化剪枝l1_unstructured 前面介绍了对指定的conv1层的weight和bias进行了不同方法的剪枝那么能不能支持同时对多层网络的特定参数进行剪枝呢
# 对于模型多个模块进行bias剪枝
for n, m in model.named_modules():# 对模型中所有的卷积层执行l1_unstructured剪枝操作, 选取20%的参数剪枝if isinstance(m, torch.nn.Conv2d):prune.l1_unstructured(m, namebias, amount0.2)# 对模型中所有全连接层执行ln_structured剪枝操作, 选取40%的参数剪枝# elif isinstance(module, torch.nn.Linear):# prune.random_structured(module, nameweight, amount0.4,dim0)# 再次打印模型参数
print( model state_dict keys:)
print(model.state_dict().keys())
print(**50)print( module named_parameters:)
print(list(module.named_parameters()))
print(**50)print( module named_buffers:)
print(list(module.named_buffers()))
print(**50)print( module weight:)
print(module.weight)
print(**50)print( module bias:)
print(module.bias)
print(**50)print( module _forward_pre_hooks:)
print(module._forward_pre_hooks)model state_dict keys:
odict_keys([conv1.weight_orig, conv1.bias_orig, conv1.weight_mask, conv1.bias_mask, conv2.weight, conv2.bias_orig, conv2.bias_mask, fc1.weight, fc1.bias, fc2.weight, fc2.bias, fc3.weight, fc3.bias])
**************************************************module named_parameters:[(weight_orig, Parameter containing:...(bias_orig, Parameter containing:...
**************************************************
# # weight_mask不变bias_mask更新
module named_buffers:
[(weight_mask, ...
(bias_mask, tensor([1., 1., 0., 0., 1., 1.]))]
**************************************************
# module weight不变
module weight:...
**************************************************
module bias:
tensor([-0.0893, -0.1464, -0.0000, -0.0000, 0.1493, -0.0418],grad_fnMulBackward0)
**************************************************
module _forward_pre_hooks:
OrderedDict([(1, torch.nn.utils.prune.PruningContainer object at 0x7f04c86756d0), (3, torch.nn.utils.prune.PruningContainer object at 0x7f04010c1100)])2.3 永久化剪枝remove
接下来对模型的weight和bias参数进行永久化剪枝操作prune.remove。
# 对module的weight执行剪枝永久化操作remove
for n, m in model.named_modules():if isinstance(m, torch.nn.Conv2d):prune.remove(m, bias)# 对conv1的weight执行剪枝永久化操作remove
prune.remove(module, weight)
print(**50)# 将剪枝后的模型的状态字典打印出来
print( model state_dict keys:)
print(model.state_dict().keys())
print(**50)# 再次打印模型参数
print( model named_parameters:)
print(list(module.named_parameters()))
print(**50)# 再次打印模型mask buffers参数
print( model named_buffers:)
print(list(module.named_buffers()))
print(**50)# 再次打印模型的_forward_pre_hooks
print( model forward_pre_hooks:)
print(module._forward_pre_hooks)**************************************************model state_dict keys:
odict_keys([conv1.bias, conv1.weight, conv2.weight, conv2.bias, fc1.weight, fc1.bias, fc2.weight, fc2.bias, fc3.weight, fc3.bias])
**************************************************model named_parameters:
[(bias, Parameter containing:
tensor([-0.0893, -0.1464, -0.0000, -0.0000, 0.1493, -0.0418],requires_gradTrue)), (weight, Parameter containing:
tensor([[[[ 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],[ 0.0000, -0.0000, 0.0000, -0.0000, 0.0000],[ 0.0000, 0.0000, -0.0000, -0.0000, -0.0000],[-0.0000, 0.0000, -0.0000, -0.0000, 0.0000],[ 0.0000, 0.0000, -0.0000, -0.0000, -0.0000]]],[[[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],[ 0.0000, 0.0000, 0.0000, -0.0000, -0.0000],[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000],[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000],[-0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],[[[ 0.1216, -0.0833, -0.1491, -0.1143, 0.0113],[ 0.0452, 0.1662, -0.0425, -0.0904, -0.1235],[ 0.0565, 0.0933, -0.0721, 0.0909, 0.1837],[-0.1739, 0.0263, 0.1339, 0.0648, -0.0382],[-0.1667, 0.1478, 0.0448, -0.0892, 0.0815]]],[[[ 0.0000, 0.0000, 0.0000, -0.0000, 0.0000],[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000],[-0.0000, 0.0000, -0.0000, -0.0000, 0.0000],[-0.0000, -0.0000, 0.0000, -0.0000, 0.0000],[ 0.0000, -0.0000, 0.0000, -0.0000, -0.0000]]],[[[ 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],[ 0.0000, -0.0000, 0.0000, -0.0000, 0.0000],[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000],[ 0.0000, -0.0000, -0.0000, -0.0000, 0.0000]]],[[[-0.0889, -0.0737, -0.1655, -0.1708, -0.0988],[-0.1787, 0.1127, 0.0706, -0.0352, 0.1238],[-0.0985, -0.1929, -0.0062, 0.0488, -0.1152],[-0.1659, -0.0448, 0.0821, -0.0956, -0.0262],[ 0.1928, 0.1767, -0.1792, -0.1364, 0.0507]]]], requires_gradTrue))]
**************************************************model named_buffers:
[]
**************************************************model forward_pre_hooks:
OrderedDict()可见执行remove操作后
weight_orig 和 bias_orig 被移除剪枝后的weight 和 bias 成为标准的模型参数。经过 remove 操作后剪枝永久化生效。剪枝掩码weight_mask 和 bias_mask不再需要named_buffers被清空_forward_pre_hooks 也被清空由于剪枝后的权重和偏置将直接反映在最终模型中所以无须再借助外部的掩码或钩子函数来维护剪枝过程。
三、全局剪枝(GLobal pruning) 前面已经介绍了局部剪枝的四种方法但这很大程度上需要根据自己的经验来决定对某一层网络进行剪枝。 更通用的剪枝策略是采用全局剪枝即从整体网络的角度进行剪枝。采用全局剪枝时不同的层被剪掉的百分比可能不同。
model LeNet().to(devicedevice)# 1.打印初始化模型的状态字典
print(model.state_dict().keys())
print(**50)# 2.构建参数集合, 决定哪些层, 哪些参数集合参与剪枝
parameters_to_prune ((model.conv1, weight),(model.conv2, weight),(model.fc1, weight),(model.fc2, weight))# 3. 全局剪枝
prune.global_unstructured(parameters_to_prune, pruning_methodprune.L1Unstructured, amount0.2)# 4. 打印剪枝后模型的状态字典
print(model.state_dict().keys())odict_keys([conv1.weight, conv1.bias, conv2.weight, conv2.bias, fc1.weight, fc1.bias, fc2.weight, fc2.bias, fc3.weight, fc3.bias])
**************************************************
odict_keys([conv1.bias, conv1.weight_orig, conv1.weight_mask, conv2.bias, conv2.weight_orig, conv2.weight_mask, fc1.bias, fc1.weight_orig, fc1.weight_mask, fc2.bias, fc2.weight_orig, fc2.weight_mask, fc3.weight, fc3.bias])打印一下各层被剪枝的比例
print(Sparsity in conv1.weight: {:.2f}%.format(100. * float(torch.sum(model.conv1.weight 0))/ float(model.conv1.weight.nelement())))print(Sparsity in conv2.weight: {:.2f}%.format(100. * float(torch.sum(model.conv2.weight 0))/ float(model.conv2.weight.nelement())))print(Sparsity in fc1.weight: {:.2f}%.format(100. * float(torch.sum(model.fc1.weight 0))/ float(model.fc1.weight.nelement())))print(Sparsity in fc2.weight: {:.2f}%.format(100. * float(torch.sum(model.fc2.weight 0))/ float(model.fc2.weight.nelement())))print(Global sparsity: {:.2f}%.format(100. * float(torch.sum(model.conv1.weight 0) torch.sum(model.conv2.weight 0) torch.sum(model.fc1.weight 0) torch.sum(model.fc2.weight 0))/ float(model.conv1.weight.nelement() model.conv2.weight.nelement() model.fc1.weight.nelement() model.fc2.weight.nelement())))
Sparsity in conv1.weight: 5.33%
Sparsity in conv2.weight: 17.25%
Sparsity in fc1.weight: 22.03%
Sparsity in fc2.weight: 14.67%
Global sparsity: 20.00%四、自定义剪枝(Custom pruning) 剪枝模型通过继承class BasePruningMethod()来执行剪枝, 内部有若干方法: call, apply_mask, apply, prune, remove等等。其中必须实现__init__构造函数和compute_mask两个函数才能完成自定义的剪枝规则设定。 此外您必须指定要实现的修剪类型 global, structured, and unstructured。
# 自定义剪枝方法的类, 一定要继承prune.BasePruningMethod
class custom_prune(prune.BasePruningMethod):# 指定此技术实现的修剪类型支持的选项为global、 structured和unstructuredPRUNING_TYPE unstructured# 内部实现compute_mask函数, 定义剪枝规则, 本质上就是如何去mask掉权重参数def compute_mask(self, t, default_mask):mask default_mask.clone()# 此处定义的规则是每隔一个参数就遮掩掉一个, 最终参与剪枝的参数的50%被mask掉mask.view(-1)[::2] 0return mask# 自定义剪枝方法的函数, 内部直接调用剪枝类的方法apply
def custome_unstructured_pruning(module, name):custom_prune.apply(module, name)return moduleimport time
# 实例化模型类
model LeNet().to(devicedevice)start time.time()
# 调用自定义剪枝方法的函数, 对model中的第1个全连接层fc1中的偏置bias执行自定义剪枝
custome_unstructured_pruning(model.fc1, namebias)# 剪枝成功的最大标志, 就是拥有了bias_mask参数
print(model.fc1.bias_mask)# 打印一下自定义剪枝的耗时
duration time.time() - start
print(duration * 1000, ms)tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
5.576610565185547 ms