Improve lifted transforms docstrings
marcvanzee opened this issue · comments
(After feedback from @zhangqiaorjc)
Some nn.scan arguments are under-documented or make undocumented assumptions. In the example below, the scanning actually happens only over the third argument of body_fn (sub and layer_in are not scanned over):
def body_fn(sub, layer_in, layer_states):
  # some code
  ...

scan_fn = nn.scan(
    body_fn,
    in_axes=0,  # scan over axis 0 for layer_states only
    variable_axes=SCAN_VARIABLE_AXES,
    split_rngs=SCAN_SPLIT_RNGS)
However, this is not made clear in the current documentation:
in_axes: Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.
We should go over all arguments of nn.scan and check whether they make any assumptions that should be mentioned.
#1980 improves some docstrings (it also adds a variation of the example above to the nn.scan docstring). In that PR I added placeholders with #1977 where docstrings should be improved further.
We should also replace self.is_mutable_collection("params") in the docstrings with self.is_initializing().
Here is a checklist of all transforms:
- vmap
- jit
- checkpoint
- remat
- remat_scan
- scan
- map_variables
- vjp
- jvp
- while_loop
- cond
- switch
- custom_vjp
For the ones that are checked, we have recently improved the documentation.