probabilists / zuko

Normalizing flows in PyTorch

Home Page: https://zuko.readthedocs.io


Initialization takes very long

michaeldeistler opened this issue · comments

Description

Thanks a lot for the package, the API is lovely! Unfortunately, instantiating flows in high dimensions takes a very long time. Is there a known reason for this, and is there a simple trick to speed it up? Thanks a lot for your help!

Reproduce

import torch
import zuko

dim_density = 1000

x = torch.randn(dim_density)  # sample to evaluate
y = torch.randn(5)            # conditioning context

# Constructing the flow is the slow step.
flow = zuko.flows.NSF(dim_density, 5, transforms=1, hidden_features=[128] * 2)

This code takes about 3 minutes to run. The equivalent construction in nflows is instant:

from nflows import transforms, distributions, flows

dim_density = 1000

transform = transforms.CompositeTransform([
    transforms.MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
        features=dim_density, 
        hidden_features=128,
        tails="linear",
        tail_bound=3.0,
    ),
    transforms.RandomPermutation(features=dim_density)
])
base_distribution = distributions.StandardNormal(shape=[dim_density])
flow = flows.Flow(transform=transform, distribution=base_distribution)

Environment

  • Zuko version: cutting-edge GitHub version
  • PyTorch version: 1.12.0+cu102
  • Python version: 3.9.5
  • OS: Ubuntu 20.04

Thank you for the report! That's an interesting bug 🤨 It will need some debugging. My guess is that the bottleneck is the MaskedMLP initialization for such a large adjacency matrix: 1000 features with 8 spline bins means 24000 parameters per transformation, and hence a 24000 × 24000 precedence matrix.
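A back-of-the-envelope count makes the scale concrete. Assuming roughly 3 spline parameters per bin per feature (bin widths, heights, and knot derivatives), which matches the 24000 figure above:

```python
# Rough parameter count for one autoregressive spline transformation.
features = 1000
bins = 8
params_per_feature = 3 * bins              # ~24 spline parameters per feature
total_outputs = features * params_per_feature

print(total_outputs)        # 24000 hypernetwork outputs
print(total_outputs ** 2)   # 576000000 entries in a 24000 x 24000 matrix
```

So any initialization step that materializes or scans a dense matrix over the output dimension touches on the order of half a billion entries, which would explain minutes of construction time.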

I made a PR to fix the issue. It seems to work well and passes the tests. I need to check a few things before pushing, but until then you can use the patch branch for your tests.

Fantastic, thanks! No rush from our side!