google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

nn.remat_scan doesn't work with nn.with_partitioning

jameslyon opened this issue

Problem you have encountered:

Using nn.remat_scan on a layer that defines its variables with nn.with_partitioning fails during init with a PartitioningUnspecifiedError.

What you expected to happen:

It seems like the default behavior should be to replicate along the scan axis if nothing else is specified.

I've seen the comments about preferring scan(remat(...)) (https://flax.readthedocs.io/en/latest/faq.html#is-flax-linen-remat-scan-the-same-as-scan-remat). Still, remat_scan is very useful because it provides the right defaults for repeatedly applying layers whose input and output have the same shape, even if its behavior regarding remat is a bit confusing.

Logs, error messages, etc:

PartitioningUnspecifiedError: Trying to transform a Partitioned variable but "partition_name" is not specified in metadata_params

Steps to reproduce:

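import jax
import jax.numpy as jnp
import flax.linen as nn

# Dense layer with a partitioned kernel, repeated twice via remat_scan.
layer = nn.remat_scan(nn.Dense, lengths=(2,))(
    3, kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('replica',))
)

layer.init(jax.random.key(0), jnp.zeros((3,)))  # raises PartitioningUnspecifiedError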

https://colab.research.google.com/drive/1f7f4HdNsROOGXzkWQIriAp5CocGwxQ1y#scrollTo=7P0lwCb9aSLp