google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Improve documentation of key concepts behind initialization in nn.Module

billmark opened this issue

This bug is for improving documentation for a particular aspect of Flax. Specifically, I (and several other people I've talked to) find the existing documentation very confusing when we try to understand the fundamental concepts behind how nn.Module initialization works, particularly when nn.compact is not used.

The following email from Flax engineer Anselm Levskaya was very helpful (Thanks, Anselm!). This bug is to request that the following information be cleaned up and added to the official documentation.


So in the situation where you have an nn.Module with a setup() and two different methods like encode() or decode():

The point of setup() in the presence of multiple methods is to provide a unique definition of what the variables/submodules in a module are. All variable and submodule definitions are forced to occur inside setup() which is always executed before any call to encode, decode, etc.

Now, those submodules are initialized lazily themselves! So to be clear, you -can- cook up a module to have two methods that use those shared variables/submodules differently, and if you jitted an init function only calling one or the other method, the resulting variables from the jitted init would be different. We enforce a unique and unambiguous set of variable/submodule definitions, but we intentionally don't constrain the user to evaluate the Module in a particular way to force initialization across the entire tree of variables.

Now, in practice, by far the most common cases of multiple-method Modules are:

  • wanting to call different "subpieces" of a larger computation, and typically the entire computation is used to initialize the variables. (again e.g. transformer encoder-decoder - subpieces used in inference loop, other autoencoders, etc.)
  • having two functions that use exactly the same variables in different ways - e.g. an Embedding that can lookup (index) an embedding table or 'attend' (matmul against vector) to the embedding table

There have been a handful of users who wanted to smash two disjoint things into the same module for convenience and who wanted a single consistent global init. In that case we typically recommend writing a forcing method that takes inputs for shape inference and then triggers initialization across all defined submodules when used as the method to init the module. But this is actually pretty uncommon, and it may be cleaner to factor things like this into separate modules.

Backing up, you might ask -why- build a lazy system like this at all? Why not just use eager initialization on construction as in PyTorch?

The point of this laziness is to 1) enable shape inference based on the inputs, 2) build a system where 'init' and 'apply' are basically the same codepath, and where a -single- transform specification can transform both initialization and application codepaths, and 3) avoid double-recursion in the case of nested transforms during the tracing of each module, where each module has to be treated as a function by JAX.

Flax was designed around the idea that transforms should be first-class citizens, and not impossibly awkward to use, since that's really the unique thing about JAX. One can debate whether that's the right priority, but lazy initialization is a consequence of that.

I have a question that appears to be related to the above. I'm attempting to implement a multi-task model in Flax, and I'm running into trouble due to this lazy initialization. For a simple conceptual example, say I have three types of entities: A, B, and C. I would like to learn embeddings for each of these, and I have data about the relationships between A and B and between A and C. I have reason to believe that a shared embedding for A across these tasks should improve performance. What is the Flax-like way to implement such parameter sharing?

First, I tried initializing all embeddings in setup and defining __call__ with a mode argument specifying whether I am providing B or C in addition to A, but ran into the trouble with only parameters falling along the path specified in the init call being defined. I can see a number of ways to proceed:

  • Define a call mode as above that just uses all the parameters somehow, and use it only to init the params.
  • Define the embeddings and their downstream processing as entirely separate modules and manage their grads/states separately during training.
  • Define a module for each task and share parameters by making frequent use of unfreeze/update/freeze on the common embeddings.

But each of the above seems unappealing for various reasons. Is there a clean solution that I'm missing? Things like this seem like just the type of experimental workflow JAX/Flax should be good at, so I'm keen to learn a good solution, and it could be a good point for future documentation as well.

commented

"but ran into the trouble with only parameters falling along the path specified in the init call being defined"

This is indeed a downside of lazy initialization at times. The alternative case is that whether a certain Module is used or not is fixed by a hyper-parameter to the module, and in that case initializing only what you use is a feature, not a bug.

In your case I think you should indeed make a call that iterates over all modes during init. Initializing for one mode and initializing for all modes are both valid initializations, but in your case you need the second.

@billmark: @jheek is working on a Module lifecycle note that will hopefully clear up some of this confusion! See #1964.