google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Improve lifted transforms docstrings

marcvanzee opened this issue

(After feedback from @zhangqiaorjc)

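Some nn.scan arguments are under-documented or rely on undocumented assumptions. In the example below, in_axes applies only to the third argument of body_fn, layer_states: the first argument (sub) is the module and the second (layer_in) is the carry, so in_axes does not apply to either of them.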

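import flax.linen as nn

# Placeholder values (assumed); the original snippet leaves these undefined.
SCAN_VARIABLE_AXES = {'params': 0}
SCAN_SPLIT_RNGS = {'params': True}

def body_fn(sub, layer_in, layer_states):
  # sub: the module; layer_in: the carry; layer_states: the scanned input.
  ...  # body elided; must return (new_carry, per_step_output)

scan_fn = nn.scan(
    body_fn,
    in_axes=0,  # scan over axis 0 for layer_states only
    variable_axes=SCAN_VARIABLE_AXES,
    split_rngs=SCAN_SPLIT_RNGS)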

However, this is not made clear in the current documentation:

in_axes: Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.

We should go over all arguments of nn.scan and check whether they make any assumptions that should be mentioned.

#1980 improves some docstrings and adds a variation of the example above to the nn.scan docstring. In that PR I added placeholders referencing #1977 where the docstrings still need improvement.

We should also replace self.is_mutable_collection("params") in the docstrings with self.is_initializing().

Here is a check list of all transforms:

  • vmap
  • jit
  • checkpoint
  • remat
  • remat_scan
  • scan
  • map_variables
  • vjp
  • jvp
  • while_loop
  • cond
  • switch
  • custom_vjp

Checked items are transforms whose documentation we have recently improved.