nn.remat_scan doesn't work with nn.with_partitioning
jameslyon opened this issue
Problem you have encountered:
Using nn.remat_scan on a layer that defines variables using nn.with_partitioning fails.
What you expected to happen:
It seems like the default behavior should be to replicate along the scan axis if nothing else is specified.
I've seen the FAQ's recommendation to prefer scan(remat(...)) (https://flax.readthedocs.io/en/latest/faq.html#is-flax-linen-remat-scan-the-same-as-scan-remat). Still, remat_scan is very useful because it provides the right defaults for repeatedly applying layers whose input and output have the same shape, even if its behavior regarding remat is a bit confusing.
Logs, error messages, etc:
PartitioningUnspecifiedError: Trying to transform a Partitioned variable but "partition_name" is not specified in metadata_params
Steps to reproduce:
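import jax
import jax.numpy as jnp
import flax.linen as nn

# Two stacked Dense layers; the kernel is boxed as a Partitioned variable.
layer = nn.remat_scan(nn.Dense, lengths=(2,))(
    3, kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('replica',))
)
layer.init(jax.random.key(0), jnp.zeros((3,)))  # raises PartitioningUnspecifiedError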
https://colab.research.google.com/drive/1f7f4HdNsROOGXzkWQIriAp5CocGwxQ1y#scrollTo=7P0lwCb9aSLp