tushar-semwal / cca.pytorch

CCAs for looking into DNNs


CCA.pytorch

PyTorch implementation of

The GPU is now the default device for the SVD computation.
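SVCCA first reduces each layer's activations with an SVD, which is presumably the computation the default device above refers to. Below is a minimal sketch of that reduction step. It assumes activations arranged as (num_samples, num_neurons), a variance threshold such as 0.99, and PyTorch 1.8 or later for torch.linalg.svd. svd_reduce is a hypothetical name, not this repository's function.

```python
import torch


def svd_reduce(acts: torch.Tensor, keep: float = 0.99, device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper: project centered activations (num_samples x num_neurons)
    onto the fewest top singular directions whose squared singular values reach
    `keep` of the total."""
    x = acts.to(device)
    x = x - x.mean(dim=0, keepdim=True)  # center each neuron over the samples
    u, s, _ = torch.linalg.svd(x, full_matrices=False)
    energy = torch.cumsum(s ** 2, dim=0) / (s ** 2).sum()  # cumulative variance fraction
    # smallest k whose leading directions reach the threshold (clamped for round-off)
    k = min(int((energy < keep).sum().item()) + 1, s.numel())
    return u[:, :k] * s[:k]  # equals x @ V_k, shape (num_samples, k)


# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(0)
acts = torch.randn(1000, 10) @ torch.randn(10, 64)  # 64 "neurons", but only rank 10
reduced = svd_reduce(acts, keep=0.99, device=device)
print(reduced.shape)  # at most 10 columns survive
```

Running the SVD on the GPU mainly pays off when layers are wide; for small matrices the transfer cost can dominate, which is one reason to keep the device configurable.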

Requirements

  • Python>=3.6
  • PyTorch>=0.4.1
  • torchvision>=0.2.1

To run example.py, you also need

Usage

import torch
from torchvision.models import resnet18
from cca import CCAHook

device = "cuda" # or "cpu"
model = resnet18() # example model; any model with these submodule names works
hook1 = CCAHook(model, "layer3.0.conv1", svd_device=device)
hook2 = CCAHook(model, "layer3.0.conv2", svd_device=device)
model.eval()
with torch.no_grad():
    model(torch.randn(1200, 3, 224, 224))
hook1.distance(hook2, size=8) # resize to 8x8
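The layer names passed to CCAHook, such as "layer3.0.conv1", appear to be the dotted submodule names that PyTorch reports through model.named_modules(). The sketch below shows one plausible way to record a named layer's outputs with a standard forward hook. ActivationRecorder is a hypothetical helper, not this library's CCAHook, and a small nn.Sequential stands in for ResNet so the example runs quickly on the CPU.

```python
import torch
import torch.nn as nn


class ActivationRecorder:
    """Hypothetical helper: records every output of one named submodule."""

    def __init__(self, model: nn.Module, name: str):
        module = dict(model.named_modules())[name]  # KeyError if the name is wrong
        self.outputs = []
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        self.outputs.append(output.detach())  # keep activations, drop autograd history

    def activations(self) -> torch.Tensor:
        return torch.cat(self.outputs, dim=0)  # stack all recorded batches

    def remove(self):
        self.handle.remove()  # stop recording


model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1),
)
rec1 = ActivationRecorder(model, "0")  # first conv layer
rec2 = ActivationRecorder(model, "2")  # second conv layer
model.eval()
with torch.no_grad():
    model(torch.randn(4, 3, 32, 32))
    model(torch.randn(4, 3, 32, 32))
print(rec1.activations().shape)  # torch.Size([8, 8, 32, 32])
print(rec2.activations().shape)  # torch.Size([8, 16, 32, 32])
```

Detaching inside the hook keeps memory bounded even if the forward pass is run outside torch.no_grad().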

Example

python example.py trains ResNet-20 on CIFAR-10 for 100 epochs, then measures the CCA distance between the trained model and its checkpoints.
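A CCA distance of this kind is built from canonical correlations. The following is a minimal sketch, assuming plain mean CCA rather than the SVCCA or PWCCA variants and PyTorch 1.9 or later for torch.linalg. The canonical correlations between two centered activation matrices of shape (num_samples, num_features) are the singular values of Qx^T Qy, where Qx and Qy are orthonormal bases of their column spaces. cca_correlations and mean_cca_distance are hypothetical names, and the result may differ from what hook1.distance computes.

```python
import torch


def cca_correlations(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Canonical correlations between x (n x p) and y (n x q), largest first.
    Assumes n > p, q and that both centered matrices have full column rank."""
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    qx, _ = torch.linalg.qr(x)  # orthonormal basis for the column space of x
    qy, _ = torch.linalg.qr(y)
    # The singular values of qx^T qy are the cosines of the principal angles
    # between the two column spaces, i.e. the canonical correlations.
    return torch.linalg.svdvals(qx.T @ qy).clamp(0.0, 1.0)


def mean_cca_distance(x: torch.Tensor, y: torch.Tensor) -> float:
    """1 - mean canonical correlation: 0 when the column spaces coincide."""
    return 1.0 - cca_correlations(x, y).mean().item()


torch.manual_seed(0)
trained = torch.randn(500, 10, dtype=torch.float64)          # stand-in for one layer's activations
mixed = trained @ torch.randn(10, 10, dtype=torch.float64)   # invertible linear mix of the same features
unrelated = torch.randn(500, 10, dtype=torch.float64)
print(mean_cca_distance(trained, mixed))      # close to 0
print(mean_cca_distance(trained, unrelated))  # much larger
```

In a model-versus-checkpoint comparison, x and y would be the same layer's activations from the two models on the same inputs, reshaped to two dimensions. Because CCA is invariant to invertible linear transformations, a layer whose features are only linearly remixed still scores a distance near 0.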

Note

While the original SVCCA uses DFT for resizing, we use global average pooling for simplicity.
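The two resizing strategies can be compared side by side. The sketch below assumes activations shaped (N, C, H, W) and reads "resize to size x size" as average pooling with F.adaptive_avg_pool2d, where size 1 is exactly global average pooling; that reading is an assumption about this repository. resize_by_dft is a generic Fourier resampling (keep the lowest frequencies), shown for contrast and not taken from the SVCCA authors' code. Both function names are hypothetical, and the FFT version assumes PyTorch 1.8 or later for torch.fft.

```python
import torch
import torch.nn.functional as F


def resize_by_pooling(acts: torch.Tensor, size: int) -> torch.Tensor:
    """Average-pool (N, C, H, W) activations to (N, C, size, size)."""
    return F.adaptive_avg_pool2d(acts, size)


def resize_by_dft(acts: torch.Tensor, size: int) -> torch.Tensor:
    """Fourier resampling: keep only the lowest size x size spatial frequencies.
    Requires size <= H and size <= W."""
    h, w = acts.shape[-2:]
    spec = torch.fft.fftshift(torch.fft.fft2(acts), dim=(-2, -1))  # DC moved to the center
    top, left = h // 2 - size // 2, w // 2 - size // 2
    crop = spec[..., top:top + size, left:left + size]  # central low-frequency block
    out = torch.fft.ifft2(torch.fft.ifftshift(crop, dim=(-2, -1))).real
    return out * (size * size) / (h * w)  # rescale so per-channel means are preserved


torch.manual_seed(0)
acts = torch.randn(2, 4, 32, 32)
pooled = resize_by_pooling(acts, 8)       # (2, 4, 8, 8)
fourier = resize_by_dft(acts, 8)          # (2, 4, 8, 8)
global_pool = resize_by_pooling(acts, 1)  # (2, 4, 1, 1): global average pooling
```

Pooling needs no complex arithmetic and behaves predictably for any input size, which fits the "for simplicity" rationale; the DFT version instead keeps a band-limited view of each feature map.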
