THU-MIG / torch-model-compression

针对pytorch模型的自动化模型结构分析和修改工具集,包含自动分析模型结构的模型压缩算法库

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

resnet50剪枝报错

Wq-dd opened this issue · comments

commented

你好,我在使用resnet18为主干网的retinanet时,自己使用稀疏训练后的模型剪枝会报错,我的做法是:

  1. 首先将训练好的模型计算bn的阈值得到每个bn层应该要剪枝的索引,并保存到一个dict里。
  2. 然后循环1中的dict使用torchprunner去剪枝,会遇到前面的某些层如果剪了过多通道,后面层再剪时会出现索引越界。
  3. 下面是我的部分代码。
        import torchpruner 
        # 创建ONNXGraph对象,绑定需要被剪枝的模型
        self.model.eval()
        graph = torchpruner.ONNXGraph(self.model.cpu())
        ##build ONNX静态图结构,需要指定输入的张量
        graph.build_graph(inputs=(torch.zeros(1, 3, 640, 640),))
        for i, (k, v) in enumerate(mask_dict_for_pruner.items()):
        # 获取conv1模块对应的module
            conv1_module = graph.modules[k]

            # 对前四个通道进行剪枝分析,指定对weight权重进行剪枝,剪枝前四个通道
            # weight权重out_channels对应的通道维度为0
            result = conv1_module.cut_analysis(attribute_name="weight", index=v, dim=0)

            # 剪枝执行模块执行剪枝操作,对模型完成剪枝过程.context变量提供了用于剪枝恢复的上下文
            self.model, context = torchpruner.set_cut(self.model, result)
        # 新的model即为剪枝后的模型
        print(self.model)```

请问是我的用法不对吗还是说这种先计算剪枝的索引再调用torchpruner的方法不对呢

每次剪枝后,model 对象变了,就都要重建 graph、重新执行 build_graph