google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Multiple Inheritance -> doesn't recognize as Module throws ValueError: parent must be None, Module or Scope

avital opened this issue.

Discussed in #1390

Originally posted by SauravMaheshkar June 26, 2021
I'm working on a Flax implementation for ProteinBERT: A universal deep-learning model of protein sequence and function. My work so far is in SauravMaheshkar/ProteinBERT.

I've written a simple test.py that checks instantiation using the .init() method. My test script is as follows:

from proteinbert import ProteinBERT
import jax
from jax import random


def test():

    seq = jax.random.randint(
        key=random.PRNGKey(0), minval=0, maxval=21, shape=(2, 2048)
    )
    annotation = jax.random.randint(
        key=random.PRNGKey(0), minval=0, maxval=1, shape=(2, 8943)
    )

    init_rngs = {"params": random.PRNGKey(0), "layers": random.PRNGKey(1)}

    ProteinBERT().init(init_rngs, seq, annotation)


if __name__ == "__main__":
    test()

and I'm getting this error message:

Error Message
Traceback (most recent call last):
  File "/Users/sauravmaheshkar/github/protein_bert/test.py", line 21, in <module>
    test()
  File "/Users/sauravmaheshkar/github/protein_bert/test.py", line 17, in test
    ProteinBERT().init(init_rngs, seq, annotation)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 1000, in init
    method=method, mutable=mutable, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 969, in init_with_output
    {}, *args, rngs=rngs, method=method, mutable=mutable, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 939, in apply
    )(variables, *args, **kwargs, rngs=rngs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/core/scope.py", line 687, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 1178, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 266, in wrapped_module_method
    self._try_setup()
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 679, in _try_setup
    self.setup()
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/Users/sauravmaheshkar/github/protein_bert/proteinbert/model.py", line 82, in setup
    Reduce("b n d -> b d", "mean"),
  File "<string>", line 5, in __init__
  File "/Users/sauravmaheshkar/opt/anaconda3/envs/proteinbert/lib/python3.7/site-packages/flax/linen/module.py", line 599, in __post_init__
    raise ValueError("parent must be None, Module or Scope")
ValueError: parent must be None, Module or Scope

The problem lies in the Reduce class defined in proteinbert/utils.py, which looks like this:

class Reduce(ReduceMixin, nn.Module):
    """
    Flax Module to act as a Reduce layer (from einops)
    """
    def __call__(self, input):
        return self._apply_recipe(input)

The idea is to create a Reduce layer (a Flax Module) that performs the reduce operation from einops. Although the class inherits from flax.linen.Module, instantiating it still raises a ValueError.

Any help would be much appreciated 😊.

@marcvanzee will be investigating this.

A few guiding questions:

  1. Why would einops be implemented as a Module instead of just a function?
  2. Why is multiple inheritance needed here?
  3. Regardless, this error shouldn't happen. So even if we answer (1) and (2) in a way that means there's a workaround, we should still fix this bug.

Just noticing this issue for the first time... I've seen similarly weird issues with mixins and Flax resolved in the past by simply changing the order of the base classes, e.g. class Reduce(nn.Module, ReduceMixin): to put nn.Module first. I'm not 100% sure this is the same kind of issue I've seen before with mixins, but I'd certainly be curious whether that fixes it...

I was playing around with mixins to see how they interact with Module.

The following case works:
import flax.linen as nn
import jax.numpy as jnp
import jax

class Mixin:
    def __call__(self, x):
        return self.dense(x)

class MyModule(nn.Module, Mixin):
    def setup(self):
        self.dense = nn.Dense(2)
    
module_a = MyModule()
variables = module_a.init(jax.random.PRNGKey(0), jnp.ones((1, 1)))

However, moving setup into Mixin fails:

class Mixin:
    def setup(self):
        self.dense = nn.Dense(2)
    
    def __call__(self, x):
        return self.dense(x)

class MyModule(nn.Module, Mixin):
    pass

module_b = MyModule()
module_b.init(jax.random.PRNGKey(0), jnp.ones((1, 1)))
# AttributeError: "MyModule" object has no attribute "dense"

This is fixed again by making Mixin the first base class:

class Mixin:
    def setup(self):
        self.dense = nn.Dense(2)
    
    def __call__(self, x):
        return self.dense(x)

class MyModule(Mixin, nn.Module):
    pass

Also, you cannot define compact methods on mixins (this is probably expected?):

class Mixin:
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(x)

class MyModule(nn.Module, Mixin): # swapping doesn't help
    pass

Discussion

Based on these experiments, the only insight I see is: don't define scope-dependent operations (compact, self.param/self.variable) inside mixins, as their methods will not be wrapped appropriately. I'm not sure whether there is a way to properly wrap mixin methods in __init_subclass__: either _get_local_method_names is not detecting them, or they are not available when __init_subclass__ is called.