google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io

Start using positional-only parameters in Linen

marcvanzee opened this issue

Some functions in our public API combine positional and keyword arguments in a way that is error-prone. Two examples:

# Example 1 (module.py)
def variable(self, col: str, name: str,
             init_fn: Optional[Callable[..., Any]] = None,
             *init_args) -> Variable:
  ...

Calling this function as variable('a', 'b', 'arg1', init_fn=fn) gives init_fn two values ('arg1' positionally and fn as a keyword), so the call fails with a TypeError about multiple values for argument 'init_fn'.
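To make the failure concrete, here is a minimal sketch using a hypothetical standalone `variable` function that only mirrors the signature from Example 1 (it is not the actual Flax implementation, and `fn` is a placeholder initializer):

```python
from typing import Any, Callable, Optional


# Hypothetical stand-in mirroring the signature of Module.variable in Example 1.
def variable(col: str, name: str,
             init_fn: Optional[Callable[..., Any]] = None,
             *init_args):
    return col, name, init_fn, init_args


def fn(*args):
    # Placeholder initializer; only used to show how arguments are bound.
    return args


error = None
try:
    # 'arg1' is bound to init_fn positionally, then init_fn=fn supplies it again.
    variable('a', 'b', 'arg1', init_fn=fn)
except TypeError as e:
    error = str(e)  # message mentions multiple values for argument 'init_fn'

# What the caller probably meant: init_fn positionally, followed by the init args.
result = variable('a', 'b', fn, 'arg1')
```

With the current signature, the only way to pass init args is to also pass init_fn positionally, which is easy to get wrong.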

# Example 2 (module.py)
def init(self,
         rngs: Union[PRNGKey, RNGSequences],
         *args,
         method: Optional[Callable[..., Any]] = None,
         mutable: CollectionFilter = DenyList('intermediates'),
         **kwargs) -> FrozenVariableDict:

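If users pass rngs= as a keyword argument, init breaks with a TypeError (rngs has already received the first positional argument), even though they probably wanted to forward an rngs keyword argument to their __call__ function.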
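A minimal sketch of the clash, assuming hypothetical `init_current` and `init_positional_only` functions that only mimic the shape of `init` in Example 2 (integers stand in for a PRNGKey and an input). It shows that making `rngs` positional-only lets an `rngs=` keyword flow into `**kwargs`, where it could be forwarded to `__call__`; `method=` and `mutable=` would still be consumed by `init` itself.

```python
from typing import Any, Optional


# Hypothetical stand-ins mirroring the shape of Module.init in Example 2.
def init_current(rngs, *args, method: Optional[Any] = None, **kwargs):
    # Whatever ends up in kwargs would be forwarded to __call__.
    return rngs, args, kwargs


def init_positional_only(rngs, /, *args, method: Optional[Any] = None, **kwargs):
    return rngs, args, kwargs


key = 0  # placeholder for a PRNGKey
x = 1.0  # placeholder for a model input

clash = None
try:
    # Current signature: rngs= collides with the first parameter.
    init_current(key, x, rngs={'dropout': 1})
except TypeError as e:
    clash = str(e)

# Positional-only signature: rngs= is collected into **kwargs instead.
forwarded = init_positional_only(key, x, rngs={'dropout': 1})
```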

We could resolve this by using positional-only parameters, introduced in Python 3.8: https://docs.python.org/3/whatsnew/3.8.html#positional-only-parameters

Example 1 would then look as follows:

def variable(self, col: str, name: str,
             init_fn: Optional[Callable[..., Any]] = None,
             /,
             *init_args) -> Variable:

Now variable('a', 'b', 'arg1', init_fn=fn) raises TypeError: variable() got some positional-only arguments passed as keyword arguments: 'init_fn'.
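A minimal sketch of the proposed behavior, again using a hypothetical standalone `variable` that only mirrors the positional-only signature above (requires Python 3.8+; `fn` is a placeholder initializer):

```python
from typing import Any, Callable, Optional


# Hypothetical stand-in with the proposed positional-only signature.
def variable(col: str, name: str,
             init_fn: Optional[Callable[..., Any]] = None,
             /,
             *init_args):
    return col, name, init_fn, init_args


def fn(*args):
    # Placeholder initializer.
    return args


error = None
try:
    # init_fn can no longer be passed by keyword.
    variable('a', 'b', 'arg1', init_fn=fn)
except TypeError as e:
    error = str(e)  # mentions positional-only arguments passed as keywords

# Supported call: col, name and init_fn positionally, the rest go to init_args.
ok = variable('a', 'b', fn, 'arg1', 'arg2')
```

The call still fails, but the error now states directly that init_fn must be passed positionally.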

This syntax is only available from Python 3.8 onward, but our GitHub Actions CI currently runs Python 3.7.

Upgrading to Python 3.8 is currently blocked on the following things:

  • Public Colab is using Python 3.7.13
  • HuggingFace is tested on Python 3.6.0
  • We should also make sure our Cloud setup uses at least Python 3.8.

Update: this is no longer blocked since we now only run Python 3.8-3.10.