microsoft / mup

maximal update parametrization (µP)

Home Page: https://arxiv.org/abs/2203.03466

integration with Flax?

nestordemeure opened this issue · comments

Is there any interest in integrating this work with Flax?

They already have an init function that decouples parameter initialization from model definition, which could make introducing mup fairly plug-and-play.
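
For illustration, here is a minimal sketch of what that decoupling looks like in Flax linen; the model and the rescale factor are purely illustrative, not mup's actual scheme. `model.init` hands back the parameter pytree, which can be post-processed into a different initialization without touching the model definition:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    width: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.width)(x))
        return nn.Dense(1)(x)


model = MLP(width=1024)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))

# `variables` is just a pytree, so a muP-style reinitialization can be a pure
# function over it. Here every 2D kernel is rescaled by a hard-coded factor,
# which is only a placeholder for a real width-dependent rule:
variables = jax.tree_util.tree_map(
    lambda p: p / jnp.sqrt(4.0) if p.ndim == 2 else p, variables
)
```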

Plus, they rely on optax for their optimizers. Since that library is focused on composability, you might be able to introduce a transformation that takes an optimizer and makes it mup-compatible.
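
To make that idea concrete, here is a minimal sketch (not mup's actual API or scaling rule) of an optax gradient transformation that applies a per-tensor multiplier on top of a base optimizer. `width_mults` is a hypothetical pytree, matching the parameter structure, that you would compute from the base and target model shapes:

```python
import jax
import optax


def scale_updates_by(width_mults):
    """GradientTransformation that multiplies each leaf's update by its multiplier."""

    def init_fn(params):
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        del params
        updates = jax.tree_util.tree_map(lambda u, m: u * m, updates, width_mults)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)


def mup_style_adam(learning_rate, width_mults):
    # Apply Adam's preconditioning, then the per-tensor multipliers,
    # then the global learning rate.
    return optax.chain(
        optax.scale_by_adam(),
        scale_updates_by(width_mults),
        optax.scale(-learning_rate),
    )
```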

Overall, I believe the Flax ecosystem could make mup more easily accessible to people.

Integration with Flax would be fantastic, but neither I nor @edwardjhu is familiar with it. If someone from the Flax team can work with us, we can definitely advise on the integration process.

@nestordemeure In case you're interested, I have a first draft of a port to JAX/Haiku here. If you're not attached to Flax in particular, you could use this. You could also probably adapt this design to Flax if you wanted, since Flax and Haiku are more similar to each other than Flax and torch are.

Edit: @thegregyang By the way, can you take a look at the plots in the README there? The optimal learning rate stabilizes with width, but it does look like I sometimes see better training loss for SP. Is that indicative of a bug? My coord checks look good: nothing grows with width, and the output norm (at init) decays with width.

Hey @davisyoshida, your repo looks great so far!

For your plot, you'd get better results if you tune the input, output, and hidden learning rates for your small model and scale up from there, sweeping a global lr multiplier on the x-axis (ideally, you'd tune (lr, init) for all parameter tensors, but these three learning rates should be a good practical approximation). In particular, for a fair comparison, the curves for your small model should be the same in both the SP and muP plots. Your current plots are just looking at a slice of the HP space (of (lr, init) for all parameter tensors) away from the true optimum.
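
For concreteness, a rough sketch of that protocol; all names here (`train_fn`, `BASE_WIDTH`, the grids) are hypothetical. The idea is to tune the three learning-rate groups once at the base width, then at every width sweep only a global multiplier on top of the tuned values, so the base-width curves are identical for SP and muP:

```python
import itertools

BASE_WIDTH = 256  # illustrative base (small-model) width


def tune_base_lrs(train_fn, grid=(1e-4, 3e-4, 1e-3, 3e-3)):
    """Grid-search input/hidden/output learning rates on the base-width model.

    `train_fn(width, lrs)` is assumed to train a model of the given width with
    per-group learning rates and return the final training loss.
    """
    best_loss, best_lrs = float("inf"), None
    for lrs in itertools.product(grid, repeat=3):
        named = dict(zip(("input", "hidden", "output"), lrs))
        loss = train_fn(width=BASE_WIDTH, lrs=named)
        if loss < best_loss:
            best_loss, best_lrs = loss, named
    return best_lrs


def sweep_global_multiplier(train_fn, base_lrs, widths, multipliers):
    """At each width, sweep only a global multiplier over the tuned base lrs."""
    results = {}
    for width in widths:
        for m in multipliers:
            scaled = {name: m * lr for name, lr in base_lrs.items()}
            results[(width, m)] = train_fn(width=width, lrs=scaled)
    return results
```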

Ah that makes perfect sense, I'll generate new versions of the figures. Thanks!