resnet50剪枝报错
Wq-dd opened this issue · comments
Wq-dd commented
你好,我在使用resnet18为主干网的retinanet时,自己使用稀疏训练后的模型剪枝会报错,我的做法是:
- 首先将训练好的模型计算bn的阈值得到每个bn层应该要剪枝的索引,并保存到一个dict里。
- 然后循环1中的dict使用torchprunner去剪枝,会遇到前面的某些层如果剪了过多通道,后面层再剪时会出现索引越界。
- 下面是我的部分代码。
import torchpruner
# 创建ONNXGraph对象,绑定需要被剪枝的模型
self.model.eval()
graph = torchpruner.ONNXGraph(self.model.cpu())
##build ONNX静态图结构,需要指定输入的张量
graph.build_graph(inputs=(torch.zeros(1, 3, 640, 640),))
for i, (k, v) in enumerate(mask_dict_for_pruner.items()):
# 获取conv1模块对应的module
conv1_module = graph.modules[k]
# 对前四个通道进行剪枝分析,指定对weight权重进行剪枝,剪枝前四个通道
# weight权重out_channels对应的通道维度为0
result = conv1_module.cut_analysis(attribute_name="weight", index=v, dim=0)
# 剪枝执行模块执行剪枝操作,对模型完成剪枝过程.context变量提供了用于剪枝恢复的上下文
self.model, context = torchpruner.set_cut(self.model, result)
# 新的model即为剪枝后的模型
print(self.model)```
请问是我的用法不对吗?还是说这种先计算剪枝的索引再调用torchpruner的方法不对呢?
Dahan Gong commented
每次剪枝后,model 对象变了,就都要重建 graph、重新执行 build_graph