KingJamesSong / DifferentiableSVD

A collection of differentiable SVD methods and ICCV21 "Why Approximate Matrix Square Root Outperforms Accurate SVD in Global Covariance Pooling?"

Home Page: https://arxiv.org/abs/2105.02498

Differentiable SVD

Introduction

This repository contains:

  1. The official PyTorch implementation of the ICCV21 paper Why Approximate Matrix Square Root Outperforms Accurate SVD in Global Covariance Pooling?
  2. The official PyTorch implementation of the T-PAMI paper On the Eigenvalues of Global Covariance Pooling for Fine-grained Visual Recognition.
  3. A collection of differentiable SVD methods utilized in our paper.

You can also find a presentation of our work in the slides and the poster.

About the paper

In this paper, we investigate why the approximate matrix square root calculated via Newton-Schulz iteration outperforms the accurate one computed by SVD, from the perspectives of data precision and gradient smoothness. Various remedies for computing smooth SVD gradients are investigated. We also propose a new spectral meta-layer that uses SVD in the forward pass and Padé approximants in the backward propagation to compute the gradients. The resulting SVD-Padé achieves state-of-the-art results on ImageNet and FGVC datasets.
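
To see the gradient-smoothness issue concretely, the following toy snippet (our illustration, not code from this repository) differentiates through an exact eigendecomposition: as two eigenvalues approach each other, the eigenvector gradient, which carries a 1/(λ1-λ2) factor, blows up.

import math

import torch

# Toy illustration (not from this repository): gradients through an exact
# eigendecomposition explode as the eigenvalue gap shrinks, because the
# eigenvector gradient carries a 1/(lambda_1 - lambda_2) factor.
c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s], [s, c]])  # fixed rotation, so the eigenvalues are known
for gap in (1e-1, 1e-3, 1e-5):
    A = (R @ torch.diag(torch.tensor([1.0 + gap, 1.0])) @ R.T).requires_grad_(True)
    lam, U = torch.linalg.eigh(A)
    U.sum().backward()  # any loss that touches the eigenvectors
    print(f"eigenvalue gap {gap:.0e} -> max |grad| {A.grad.abs().max():.2e}")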

Differentiable SVD Methods

As the backward algorithm of SVD is prone to numerical instability, we implement a variety of end-to-end SVD methods in this repository by manipulating the backward algorithms. They include:

  • SVD-Padé: uses Padé approximants to closely approximate the gradient. Proposed in our ICCV21 paper.
  • SVD-Taylor: uses a Taylor polynomial to approximate the smooth gradient. Proposed in our ICCV21 paper and the TPAMI journal.
  • SVD-PI: uses Power Iteration (PI) to approximate the gradients. Proposed in the NeurIPS19 paper.
  • SVD-Newton: uses the gradient of the Newton-Schulz iteration.
  • SVD-Trunc: sets an upper limit on the gradient and applies truncation.
  • SVD-TopN: selects the Top-N eigenvalues and abandons the rest.
  • SVD-Original: ordinary SVD with a gradient overflow check.

In the task of global covariance pooling, SVD-Padé achieves the best performance. You are free to try the other methods in your research; a minimal sketch of the backward structure these methods share is given below.
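
All of these methods keep the exact decomposition in the forward pass and only change how the gradient is formed. As a minimal sketch (our illustration under the standard matrix-backpropagation convention, not the repository's code), the backward pass of a symmetric eigendecomposition is built around the matrix K with K_ij = 1/(λj - λi), and each remedy above replaces or bounds this term; the clamp value below is an illustrative SVD-Trunc-style truncation.

import torch

class SymEig(torch.autograd.Function):
    # Eigendecomposition of a symmetric matrix with a truncated backward pass.

    @staticmethod
    def forward(ctx, A):
        lam, U = torch.linalg.eigh(A)  # exact decomposition in the forward pass
        ctx.save_for_backward(lam, U)
        return lam, U

    @staticmethod
    def backward(ctx, dlam, dU):
        lam, U = ctx.saved_tensors
        n = lam.numel()
        eye = torch.eye(n, dtype=lam.dtype, device=lam.device)
        # K[i, j] = 1 / (lam[j] - lam[i]) for i != j: the unstable term.
        diff = lam.unsqueeze(0) - lam.unsqueeze(1) + eye  # +eye avoids 0/0 on the diagonal
        K = (1.0 / diff) * (1.0 - eye)
        # SVD-Trunc-style remedy (1e4 is an illustrative cap): bound the
        # magnitude so near-degenerate eigenvalues cannot explode the gradient.
        K = K.clamp(-1e4, 1e4)
        dA = U @ (K * (U.T @ dU) + torch.diag(dlam)) @ U.T
        return 0.5 * (dA + dA.T)  # symmetrize, since A is symmetric

A covariance pooling layer would then call lam, U = SymEig.apply(cov) and assemble the matrix square root as U @ torch.diag(lam.clamp(min=0).sqrt()) @ U.T.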

Implementation and Usage

The code is modified from iSQRT-COV.

See requirements.txt for the required packages.

To train AlexNet on ImageNet, choose a spectral meta-layer in the script and run:

CUDA_VISIBLE_DEVICES=0,1 bash train_alexnet.sh

The pre-trained ResNet-50 model with SVD-Padé is available via Google Drive. You can load the state dict by:

model.load_state_dict(torch.load('pade_resnet50.pth.tar'))
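
Note that torch.load may return either a raw state dict or a wrapper dictionary, depending on how the checkpoint was saved; the 'state_dict' key below is an assumption, so adjust it to the actual file.

import torch

# Assumption: the .pth.tar file may wrap the weights under a 'state_dict' key,
# as many ImageNet training scripts do; fall back to the raw object otherwise.
ckpt = torch.load('pade_resnet50.pth.tar', map_location='cpu')
state = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
model.load_state_dict(state)  # model must be the same ResNet-50 architecture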

Fine-grained Classification Usage (T-PAMI paper)

The proposed Scaling Eigen Branch (SEB) is implemented here, and the training scripts are available here.

The pre-trained models are available via Google Drive.

The perturbation-based visualization results are available via Google Drive, and the back-propagation-based visualizations are available via Google Drive. We use torchray to generate backward-based explanations.

Citation

If you find the code helpful for your research, please consider citing our papers:

@inproceedings{song2021approximate,
  title={Why Approximate Matrix Square Root Outperforms Accurate SVD in Global Covariance Pooling?},
  author={Song, Yue and Sebe, Nicu and Wang, Wei},
  booktitle={ICCV},
  year={2021}
}

@article{song2022eigenvalues,
  title={On the Eigenvalues of Global Covariance Pooling for Fine-grained Visual Recognition},
  author={Song, Yue and Sebe, Nicu and Wang, Wei},
  journal={IEEE TPAMI},
  year={2022},
  publisher={IEEE}
}

Contact

If you have any questions or suggestions, please feel free to contact me:

yue.song@unitn.it

License

Apache License 2.0