minyoungg / overparam

Overparam layers

PyTorch linear over-parameterization layers with automatic graph reduction.

Official codebase used in:

The Low-Rank Simplicity Bias in Deep Networks
Minyoung Huh   Hossein Mobahi   Richard Zhang   Brian Cheung   Pulkit Agrawal   Phillip Isola
MIT CSAIL   Google Research   Adobe Research   MIT BCS
TMLR 2023 (arXiv 2021).
[project page] | [paper] | [arXiv]

1. Installation

Developed on

  • Python 3.7 🐍
  • PyTorch 1.7 🔥

> git clone https://github.com/minyoungg/overparam
> cd overparam
> pip install .

2. Usage

The layers are used exactly like any other torch.nn layer.

Getting started

(1a) OverparamLinear layer (equivalence: nn.Linear)

import torch
from overparam import OverparamLinear

layer = OverparamLinear(16, 32, width=1, depth=2)
x = torch.randn(1, 16)
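To illustrate what "equivalence: nn.Linear" means, here is a minimal sketch in plain PyTorch that does not use the overparam package. It stacks two linear maps with no nonlinearity between them and collapses them into one weight and bias. The hidden size `hidden` and the variable names are assumptions made for illustration; the library's internals may differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two stacked linear maps with no nonlinearity in between (a depth of 2).
# `hidden` is an illustrative size, not a parameter of the overparam library.
in_features, hidden, out_features = 16, 16, 32
first = nn.Linear(in_features, hidden)
second = nn.Linear(hidden, out_features)

# Collapse: second(first(x)) = (W2 W1) x + (W2 b1 + b2)
with torch.no_grad():
    weight = second.weight @ first.weight             # shape (32, 16)
    bias = second.weight @ first.bias + second.bias   # shape (32,)

    x = torch.randn(4, in_features)
    expanded = second(first(x))                       # two-layer forward pass
    collapsed = nn.functional.linear(x, weight, bias) # single-layer forward pass

print(weight.shape, torch.allclose(expanded, collapsed, atol=1e-5))
```

The collapsed weight has the same shape as the weight of an nn.Linear(16, 32), which is why the expanded stack can stand in for a single linear layer.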

(1b) OverparamConv2d layer (equivalence: nn.Conv2d)

from overparam import OverparamConv2d
import numpy as np

We can construct 3 Conv2d layers with kernel dimensions of 5x5, 3x3, and 1x1:

kernel_sizes = [5, 3, 1]

# Same padding (cast to a plain int, since np.sum returns a NumPy integer)
padding = int(max((np.sum(kernel_sizes) - len(kernel_sizes) + 1) // 2, 0))

layer = OverparamConv2d(2, 4, kernel_sizes=kernel_sizes, padding=padding,
                        depth=len(kernel_sizes))

# Get the effective kernel size
print(layer.kernel_size)

When kernel_sizes is an integer, all subsequent layers are assumed to have a kernel size of 1x1.
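The kernel-size arithmetic above can be sketched in pure Python. The helper names below (`expand_kernel_sizes`, `effective_kernel_size`, `same_padding`) are hypothetical and are not part of the overparam API. The sketch assumes stride-1, undilated convolutions, where stacking kernels of size k1 and k2 gives an effective kernel of size k1 + k2 - 1.

```python
def expand_kernel_sizes(kernel_sizes, depth):
    """Hypothetical helper: an integer k becomes [k, 1, 1, ...] of length `depth`."""
    if isinstance(kernel_sizes, int):
        return [kernel_sizes] + [1] * (depth - 1)
    return list(kernel_sizes)


def effective_kernel_size(kernel_sizes):
    """Kernel size of the single convolution equivalent to a stride-1 stack."""
    return sum(kernel_sizes) - len(kernel_sizes) + 1


def same_padding(kernel_sizes):
    """Same formula as the padding line above; preserves spatial size for odd kernels."""
    return max(effective_kernel_size(kernel_sizes) // 2, 0)


print(effective_kernel_size([5, 3, 1]))   # 7
print(same_padding([5, 3, 1]))            # 3
print(expand_kernel_sizes(3, depth=2))    # [3, 1]
```

For the 5x5, 3x3, 1x1 stack the effective kernel is 7x7, so a padding of 3 keeps the spatial size unchanged.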

(2) Forward computation

# Forward pass (expanded form)
layer.train()
y = layer(x)

Calling eval() automatically reduces the computation graph to its effective single-layer counterpart. The forward pass in eval mode then uses the effective weights instead of the expanded layers.

# Forward pass (collapsed form) [automatic]
layer.eval()
y = layer(x)

You can access the effective weights as follows:

print(layer.weight)
print(layer.bias)
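As a rough illustration of what an effective weight is for convolutions, the sketch below collapses a 3x3 convolution followed by a 1x1 convolution into a single 3x3 convolution using plain PyTorch. It is not the library's implementation; the channel counts and variable names are assumptions chosen for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Expanded form: a 3x3 convolution followed by a 1x1 convolution, no nonlinearity.
w1 = torch.randn(8, 2, 3, 3)   # (mid_channels, in_channels, 3, 3)
b1 = torch.randn(8)
w2 = torch.randn(4, 8, 1, 1)   # (out_channels, mid_channels, 1, 1)
b2 = torch.randn(4)

# A 1x1 kernel only mixes channels, so it folds into the 3x3 kernel.
mix = w2[:, :, 0, 0]                               # (4, 8)
weight = torch.einsum("om,mikl->oikl", mix, w1)    # (4, 2, 3, 3)
bias = mix @ b1 + b2                               # (4,)

x = torch.randn(1, 2, 16, 16)
expanded = F.conv2d(F.conv2d(x, w1, b1, padding=1), w2, b2)
collapsed = F.conv2d(x, weight, bias, padding=1)

print(weight.shape, torch.allclose(expanded, collapsed, atol=1e-3))
```

Both forward passes produce the same output up to floating-point error, while the collapsed form runs a single convolution.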

(3) Automatic conversion

import torchvision.models as models
from overparam.utils import overparameterize

model = models.alexnet() # Replace this with YOUR_PYTORCH_MODEL()
model = overparameterize(model, depth=2)
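For intuition about what an automatic conversion might involve, here is a hedged sketch that walks a model and replaces each nn.Linear with a stack of linear layers. `expand_linear_layers` is a hypothetical helper written for this example; it is not the library's `overparameterize`, it handles only nn.Linear, and it re-initializes weights rather than preserving them.

```python
import torch
import torch.nn as nn


def expand_linear_layers(module, depth=2):
    """Hypothetical helper: swap each nn.Linear for a stack of `depth` linear layers."""
    # Copy the child list first so that replacing children is safe while looping.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            # depth - 1 square layers, then one layer that maps to the output size.
            layers = [nn.Linear(child.in_features, child.in_features, bias=False)
                      for _ in range(depth - 1)]
            layers.append(nn.Linear(child.in_features, child.out_features,
                                    bias=child.bias is not None))
            setattr(module, name, nn.Sequential(*layers))
        else:
            expand_linear_layers(child, depth)  # recurse into containers
    return module


model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
model = expand_linear_layers(model, depth=2)

num_linear = sum(isinstance(m, nn.Linear) for m in model.modules())
out = model(torch.randn(4, 16))
print(num_linear, out.shape)
```

The input and output shapes of the model are unchanged; only the number of linear factors per layer grows.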

(4) Batch-norm and Residual connections

We also provide support for batch-norm and linear residual connections.

  • batch-normalization (pseudo-linear layer: linear during eval mode)
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2, 
                        batch_norm=True)
  • residual-connection
# every 2 layers, a residual connection is added
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2,
                        residual=True, residual_intervals=2)
  • multiple residual connections
# a residual connection is added at intervals of 1, 2, and 3 layers
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2, 
                        residual=True, residual_intervals=[1, 2, 3])
  • batch-norm and residual connection
# mimics `BasicBlock` in ResNets
layer = OverparamConv2d(32, 32, kernel_sizes=3, padding=1, depth=2, 
                        batch_norm=True, residual=True, residual_intervals=2)
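The reason batch-norm and residual connections can still be collapsed is that both are affine in eval mode. Batch-norm then uses fixed running statistics, and a residual connection adds the identity map. The sketch below folds both into a single linear layer using plain PyTorch. It is an illustration under assumed names and sizes, not the library's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

linear = nn.Linear(8, 8)
bn = nn.BatchNorm1d(8)

# Give the batch-norm non-trivial statistics and affine parameters.
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1, 1)
bn.eval()  # eval mode uses the fixed running statistics, so BN is affine

with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # per-feature gain
    weight = scale[:, None] * linear.weight                   # fold BN into W
    bias = scale * (linear.bias - bn.running_mean) + bn.bias  # fold BN into b
    weight = weight + torch.eye(8)                            # residual adds identity

    x = torch.randn(4, 8)
    expanded = bn(linear(x)) + x
    collapsed = nn.functional.linear(x, weight, bias)

print(torch.allclose(expanded, collapsed, atol=1e-5))
```

In train mode the batch statistics depend on the input batch, so this folding only holds in eval mode.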

3. Cite

@article{huh2023simplicitybias,
    title={The Low-Rank Simplicity Bias in Deep Networks},
    author={Minyoung Huh and Hossein Mobahi and Richard Zhang and Brian Cheung and Pulkit Agrawal and Phillip Isola},
    journal={Transactions on Machine Learning Research},
    issn={2835-8856},
    year={2023},
    url={https://openreview.net/forum?id=bCiNWDmlY2},
}
