Figure 1: Top-1 accuracy tradeoff curves for pruning ResNet50 on the ImageNet classification dataset using a latency cost constraint. The baseline is from the PyTorch model hub. Accuracy plotted against FPS (left) and FLOPs (right) shows the benefit of our method, particularly at high pruning ratios. For FPS, top-right is better; for FLOPs, top-left is better. FPS is measured on an NVIDIA TITAN V GPU. See the paper for more details.
Soft Masking for Cost-Constrained Channel Pruning.
Ryan Humble, Maying Shen, Jorge Albericio Latorre, Eric Darve, and Jose M. Alvarez.
ECCV 2022.
Official PyTorch code repository for the "Soft Masking for Cost-Constrained Channel Pruning" paper presented at ECCV 2022 (contact josea@nvidia.com for further inquiries).
Structured channel pruning has been shown to significantly accelerate inference time for convolutional neural networks (CNNs) on modern hardware, with a relatively minor loss of network accuracy. Recent works permanently zero these channels during training, which we observe to significantly hamper final accuracy, particularly as the fraction of the network being pruned increases. We propose Soft Masking for cost-constrained Channel Pruning (SMCP) to allow pruned channels to adaptively return to the network while simultaneously pruning towards a target cost constraint. By adding a soft mask re-parameterization of the weights and channel pruning from the perspective of removing input channels, we allow gradient updates to previously pruned channels and the opportunity for the channels to later return to the network. We then formulate input channel pruning as a global resource allocation problem. Our method outperforms prior works on both the ImageNet classification and PASCAL VOC detection datasets.
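The effect of soft masking can be illustrated with a toy, pure-Python sketch (not the repository's parameter_masking.py): a single linear unit whose effective weights are W * M. Two assumptions are made for illustration only: the gradient with respect to the masked weights is passed straight through to the dense weights, and the mask is re-selected by keeping the top-k weights by magnitude. Under soft masking a pruned channel keeps receiving updates and can re-enter the mask; under hard masking its gradient is zero, so it stays out.

```python
def forward(weights, mask, x):
    """Single-output linear unit; the effective weight of input channel i is w_i * m_i."""
    return sum(w * m * xi for w, m, xi in zip(weights, mask, x))


def grads(weights, mask, x, target, soft=True):
    """Gradient of L = 0.5 * (y - target)**2 with respect to the dense weights."""
    dy = forward(weights, mask, x) - target
    out = []
    for m, xi in zip(mask, x):
        g_eff = dy * xi  # dL/d(w_i * m_i)
        # Hard masking: the chain rule multiplies by m_i, so pruned channels get 0.
        # Soft masking (straight-through assumption): g_eff reaches w_i unchanged.
        out.append(g_eff if soft else g_eff * m)
    return out


def topk_mask(weights, k):
    """Keep the k input channels with the largest |w| (magnitude importance, an assumption)."""
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    keep = set(order[:k])
    return [1 if i in keep else 0 for i in range(len(weights))]


def train_step(weights, mask, x, target, lr, k, soft=True):
    """One SGD step on the dense weights, then re-select the mask."""
    g = grads(weights, mask, x, target, soft)
    new_w = [w - lr * gi for w, gi in zip(weights, g)]
    return new_w, topk_mask(new_w, k)


w0, m0, x, target = [0.5, 0.1, -0.4], [1, 0, 1], [1.0, 2.0, 1.0], 3.0
print(train_step(w0, m0, x, target, lr=0.1, k=2, soft=True)[1])   # [1, 1, 0]
print(train_step(w0, m0, x, target, lr=0.1, k=2, soft=False)[1])  # [1, 0, 1]
```

In the soft case, channel 1 (initially pruned) grows enough to displace channel 2; in the hard case it is frozen and never returns.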
The code sets the memory layout to NHWC (PyTorch's channels_last memory format; see the PyTorch documentation). This comes with performance benefits, as described in the NVIDIA DL performance documentation.
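To make the layout concrete, the pure-Python sketch below computes element strides for an N x C x H x W tensor stored in the default contiguous (NCHW) order versus channels-last (NHWC) order, reported in logical NCHW order the way PyTorch's Tensor.stride() reports them. In PyTorch itself, a model or tensor is typically converted with .to(memory_format=torch.channels_last); the helper functions here are illustrative only.

```python
def contiguous_strides(shape):
    """Element strides of a row-major (C-contiguous) array with the given shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def nchw_strides(n, c, h, w):
    # Default layout: W varies fastest, then H, then C, then N.
    return contiguous_strides((n, c, h, w))


def channels_last_strides(n, c, h, w):
    # Physically stored as N, H, W, C (C varies fastest), but the strides
    # are reported in the logical N, C, H, W dimension order.
    sn, sh, sw, sc = contiguous_strides((n, h, w, c))
    return (sn, sc, sh, sw)


print(nchw_strides(2, 3, 4, 5))           # (60, 20, 5, 1)
print(channels_last_strides(2, 3, 4, 5))  # (60, 1, 15, 3)
```

A channel stride of 1 means all channels of one pixel are adjacent in memory, which is the access pattern NVIDIA's convolution kernels prefer.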
We adopt an input channel pruning approach, as described in the paper: importance scores and masks are always computed along input channels. The cost, however, can be measured more flexibly via the channel-doublesided-weight argument: 1 (the default) measures cost with output channels fixed, 0 measures it with input channels fixed (as in HALP), and values in between combine the two.
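One plausible reading of this argument is sketched below; it is a hypothetical illustration, not the formula in channel_costing.py. The input channels of a layer are also the output channels of the layer that produces them, so their cost can be charged to the consuming layer (its output channels held fixed, weight 1) or to the producing layer (its input channels held fixed, weight 0). The sketch assumes linear interpolation in between and uses a dense convolution's multiply-accumulate count as the cost; the function names and dictionary keys are made up.

```python
def conv_macs(c_in, c_out, k, h_out, w_out):
    """Multiply-accumulate count of a dense (groups=1) k x k convolution."""
    return c_in * c_out * k * k * h_out * w_out


def doublesided_cost(n_kept, layer, prev_layer, weight=1.0):
    """Hypothetical cost of keeping n_kept input channels of `layer`.

    The same channels are the output channels of `prev_layer`.
    weight=1.0: charge `layer`, with its output channels held fixed.
    weight=0.0: charge `prev_layer`, with its input channels held fixed.
    In between: linear interpolation (an assumption for this sketch).
    """
    this_side = conv_macs(n_kept, layer["c_out"], layer["k"], layer["h"], layer["w"])
    prev_side = conv_macs(prev_layer["c_in"], n_kept, prev_layer["k"],
                          prev_layer["h"], prev_layer["w"])
    return weight * this_side + (1.0 - weight) * prev_side


prev = {"c_in": 64, "k": 3, "h": 56, "w": 56}     # 3x3 conv producing the channels
layer = {"c_out": 128, "k": 1, "h": 56, "w": 56}  # 1x1 conv consuming them
for wgt in (1.0, 0.5, 0.0):
    print(wgt, doublesided_cost(32, layer, prev, wgt))
```

Here keeping 32 channels costs 12,845,056 MACs when charged to the 1x1 consumer and 57,802,752 MACs when charged to the 3x3 producer, so the choice of weight changes how expensive each channel looks to the optimizer.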
Soft channel pruning supports only a limited set of architectures. We automatically detect the channel structure of the network (which layers must be pruned together, which layers can be layer-pruned, etc.); this detection logic is only known to work for standard ResNet architectures, MobileNetV1, and SSD512-RN50. Main limitations:
- Group convolutions: Group convolutions are hard to handle, so we only support normal convolutions (groups=1) or depthwise convolutions (groups=in_channels=out_channels).
- Non-convolution operations: The code only handles convolution, linear, and batch normalization layers as meaningfully interacting with the channels in the network. All other operations are assumed to change neither the number of channels nor which tensor dimension corresponds to the channel (i.e., the second dimension for feature maps, as the first is the batch dimension). See the description of the code below for more details.
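Part of the structure detection is finding layers whose channels are tied together, for example convolutions whose outputs meet at a residual addition and therefore must keep the same channel set. The sketch below is a hypothetical, simplified version of that idea using union-find over a hand-written list of additions; the real channel_structure.py discovers the structure from a torch.fx graph, and the layer names here are made up.

```python
class UnionFind:
    """Disjoint sets with path halving."""

    def __init__(self):
        self.parent = {}

    def find(self, a):
        self.parent.setdefault(a, a)
        while self.parent[a] != a:
            self.parent[a] = self.parent[self.parent[a]]
            a = self.parent[a]
        return a

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra != rb:
            self.parent[rb] = ra


def channel_groups(layers, adds):
    """Group layers whose output channels are summed element-wise.

    adds: list of operand lists; every layer feeding the same sum must share channels.
    """
    uf = UnionFind()
    for name in layers:
        uf.find(name)
    for operands in adds:
        for other in operands[1:]:
            uf.union(operands[0], other)
    groups = {}
    for name in layers:
        groups.setdefault(uf.find(name), []).append(name)
    return sorted(sorted(g) for g in groups.values())


layers = ["stem", "block1.conv1", "block1.conv3", "block2.conv1", "block2.conv3"]
# Identity shortcuts carry the stem's channels through both residual additions.
adds = [["stem", "block1.conv3"], ["block1.conv3", "block2.conv3"]]
print(channel_groups(layers, adds))
```

The inner block convolutions (conv1) form singleton groups and can be pruned independently, while the stem and both block outputs form one group that must be pruned jointly.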
Once training is complete, the slimmed model can be obtained with the method in model_clean.py (which uses channel_slimmer.py internally). This removes the pruned channels and saves the entire network object rather than its state dict (see the PyTorch documentation on saving and loading models for details). Saving or loading only the slimmed state dict is not supported.
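Slimming changes parameter shapes, which is why the whole network is saved: a state dict of slimmed weights no longer matches the shapes expected by the original model definition. The pure-Python sketch below shows the shape effect for two consecutive 1x1 convolutions with weights stored as nested lists indexed [out_channel][in_channel]. It assumes the pruned channels feed only this one consumer; the function is hypothetical and far simpler than channel_slimmer.py.

```python
def slim_pair(prev_weight, weight, mask):
    """Drop channel i wherever mask[i] == 0.

    prev_weight: [out][in] weights of the producing 1x1 conv (rows = its output channels).
    weight:      [out][in] weights of the consuming 1x1 conv (columns = its input channels).
    Assumes the consuming layer is the only reader of these channels.
    """
    prev_slim = [row for row, keep in zip(prev_weight, mask) if keep]
    this_slim = [[v for v, keep in zip(row, mask) if keep] for row in weight]
    return prev_slim, this_slim


def shape(matrix):
    return (len(matrix), len(matrix[0]))


prev_w = [[1, 2], [3, 4], [5, 6], [7, 8]]            # 4 outputs x 2 inputs
this_w = [[1, 0, 2, 0], [0, 3, 0, 4], [5, 5, 5, 5]]  # 3 outputs x 4 inputs
prev_s, this_s = slim_pair(prev_w, this_w, mask=[1, 0, 1, 0])
print(shape(prev_w), "->", shape(prev_s))  # (4, 2) -> (2, 2)
print(shape(this_w), "->", shape(this_s))  # (3, 4) -> (3, 2)
```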
To measure latency, simply load the cleaned model and time its forward pass as usual.
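A generic timing helper along these lines can be used with the cleaned model; it is a sketch, not a script from the repository. It assumes forward is any callable, such as the model in eval mode (typically inside torch.no_grad()). Because CUDA kernels launch asynchronously, on a GPU you would pass sync=torch.cuda.synchronize so the timer waits for queued work to finish.

```python
import time


def measure_fps(forward, batch, batch_size, warmup=10, iters=50, sync=None):
    """Return throughput in samples per second for repeated forward(batch) calls.

    sync: optional callable that blocks until queued device work is done
    (for example torch.cuda.synchronize on a GPU); None on CPU.
    """
    for _ in range(warmup):  # warm-up iterations are excluded from timing
        forward(batch)
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        forward(batch)
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return iters * batch_size / elapsed


# Toy stand-in for a model: sleeps about 2 ms per call.
fps = measure_fps(lambda b: time.sleep(0.002), batch=None, batch_size=8, warmup=2, iters=5)
print(f"{fps:.0f} samples/s")
```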
This repository uses PyTorch Lightning to handle most of the training intricacies, including (but not limited to):
- Single- and multi-GPU training (via DDP)
- Sync batch norm (for object detection code)
- Automatic mixed precision
- Logging
- Model checkpoints
- Learning rate schedules
- Metric calculation (accuracy for classification and mAP for detection)
PyTorch Lightning provides a callback mechanism for integrating custom behavior. We implement a DynamicPruning callback class that connects our pruning code (which does not depend on PyTorch Lightning) to the PyTorch Lightning training setup.
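The pattern is sketched below with minimal stand-in classes so it runs without PyTorch Lightning installed. All class names are hypothetical, and Lightning's real Callback.on_train_batch_end hook receives different arguments (in recent versions trainer, pl_module, outputs, batch and batch_idx). The point is the separation: the pruner knows nothing about the trainer, and the callback only decides when to invoke it.

```python
class Callback:
    """Minimal stand-in for a training-framework callback base class (hypothetical)."""

    def on_train_batch_end(self, trainer, step):
        pass


class Trainer:
    """Toy training loop that fires callbacks after every batch."""

    def __init__(self, callbacks):
        self.callbacks = callbacks

    def fit(self, num_steps):
        for step in range(num_steps):
            # Forward pass, backward pass, and optimizer step would happen here.
            for cb in self.callbacks:
                cb.on_train_batch_end(self, step)


class Pruner:
    """Framework-independent pruning logic (hypothetical interface)."""

    def __init__(self):
        self.calls = []

    def prune(self, step):
        self.calls.append(step)  # a real pruner would update masks here


class DynamicPruningSketch(Callback):
    """Connects the pruner to the training loop: prune every `every` steps."""

    def __init__(self, pruner, every):
        self.pruner, self.every = pruner, every

    def on_train_batch_end(self, trainer, step):
        if (step + 1) % self.every == 0:
            self.pruner.prune(step)


pruner = Pruner()
Trainer([DynamicPruningSketch(pruner, every=4)]).fit(10)
print(pruner.calls)  # [3, 7]
```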
Code is located in the Classification folder.
Run ResNet50 on ImageNet without pruning:
python -m smcp.classification.image_classifier --dataset Imagenet --data-root=/some/path --gpus=1 --fp16
With dynamic input channel pruning:
... --prune --channel-type=Global --channel-ratio=0.3
See the full set of command-line arguments here.
Code is located in the Object Detection folder.
Run SSD512-RN50 on PascalVOC without pruning:
python -m smcp.detection.object_detection --dataset PascalVOC --data-root=/some/path --gpus=1 --fp16
With dynamic input channel pruning:
... --prune --channel-type=Global --channel-ratio=0.3
- image_classifier.py: main training script for CIFAR10/100/ImageNet
- image_inference.py: experimental model cleaning and inference timing script for image classifier models
- datasets: folder for CIFAR10/100/ImageNet image classification datasets, written as PyTorch Lightning's LightningDataModules (see their documentation for details)
- models: folder for ResNet and MobileNetV1 model definitions
- object_detection.py: main training script for Pascal VOC
- metrics.py: code to take SSD output, convert it to detections, and calculate mAP. Includes a custom SSDDetectionMAP metric to calculate and log mAP periodically during training.
- datasets: folder for the Pascal VOC object detection dataset, written as a PyTorch Lightning LightningDataModule (see their documentation for details)
- models: folder for SSD object detection model definitions
- base_pruning.py: base class BasePruningMethod for different pruning methods
- bn_scaling.py: re-parameterize the BN weights to perform scaling on them
- channel_costing.py: different costing functions for cost-constrained pruning
- channel_pruning.py: HUGE file for all of dynamic input channel pruning
- channel_slimmer.py: automatic slimming code for channel pruning
  - Heavily relies on torch.fx
  - Works for both output and input channel pruning
  - See restrictions on supported architectures mentioned above
- channel_structure.py: automatic network structure discovery
  - Heavily relies on torch.fx
  - Distinguishes between channel acting and channel producing nodes
    - Channel producing: creates any number of output channels independent of the number of input channels
    - Channel acting: number of output channels is a function of the number of input channels AND the layer is stateful
- decay.py: pruned decay (as defined in this paper)
- dynamic_pruning.py: connector between any BasePruningMethod and the PyTorch Lightning training framework
- group_knapsack.py: different solver variants for the multiple-choice knapsack problem described in the paper
- importance_accumulator.py: different ways to accumulate importance over the steps between pruning iterations
- importance.py: different ways of calculating importance
- model_clean.py: wrapper for how to clean/slim the model
- parameter_masking.py: re-parameterize the network weights to perform masking; defines several types of masking
- result.py: base classes for logging the pruning results (# parameters pruned, unpruned, etc.)
- scheduler.py: different pruning schedulers
- shape_propagation.py: torch.fx.Interpreter to propagate and store feature map shapes through the network
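The multiple-choice knapsack problem that group_knapsack.py solves can be illustrated with a textbook dynamic program: each group (for example, a layer) offers several options, such as keeping different numbers of channels, each with an integer cost and an importance value, and exactly one option must be chosen per group so that total importance is maximized within the cost budget. This exact DP over integer costs is an assumption for illustration; the repository provides several solver variants that may work differently.

```python
NEG = float("-inf")


def mckp(groups, budget):
    """Multiple-choice knapsack: pick exactly one (cost, value) option per group.

    Costs must be non-negative integers. Returns (best_value, option_indices),
    or None if no selection fits within the budget.
    """
    best = [0.0] + [NEG] * budget  # best[c]: max value at total cost exactly c
    back = []                      # back[g][c] = (option index, previous cost)
    for options in groups:
        new = [NEG] * (budget + 1)
        choice = [None] * (budget + 1)
        for c in range(budget + 1):
            if best[c] == NEG:
                continue
            for j, (cost, value) in enumerate(options):
                nc = c + cost
                if nc <= budget and best[c] + value > new[nc]:
                    new[nc] = best[c] + value
                    choice[nc] = (j, c)
        best = new
        back.append(choice)
    end = max(range(budget + 1), key=lambda c: best[c])
    if best[end] == NEG:
        return None
    picks, c = [], end
    for choice in reversed(back):  # walk back from the last group to the first
        j, c = choice[c]
        picks.append(j)
    return best[end], picks[::-1]


# Two layers; options = (cost, importance) of keeping a small, medium, or large channel set.
layers = [
    [(1, 1.0), (2, 3.0), (4, 4.0)],
    [(1, 2.0), (3, 5.0), (5, 6.0)],
]
print(mckp(layers, budget=6))  # (8.0, [1, 1])
```

Choosing the medium option in both layers (cost 5, importance 8.0) beats spending the whole budget on one large layer, which is the kind of global tradeoff a per-layer pruning ratio cannot express.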
Please check the LICENSE file. SMCP may be used non-commercially. For business inquiries, please contact researchinquiries@nvidia.com.
@article{Humble2022pruning,
title={Soft Masking for Cost-Constrained Channel Pruning},
author={Humble, Ryan and Shen, Maying and Albericio-Latorre, Jorge and Darve, Eric and Alvarez, Jose M},
journal={ECCV},
year={2022}
}