google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Feature request: ability to apply stop gradient to some parameters

NeilGirdhar opened this issue · comments

To motivate this feature request, I'll explain what I'm currently doing (without Flax) and the other solutions I've considered. Then I'll suggest some possible Flax solutions.

Problem:

In the process of inferring one of my modules, I need to mask varying subsets of the weights with stop-gradient in a single function:

def infer(encoding: EncodingElement,
          observation: PoolingMessage,
          prediction: PredictionMessage,
          rng: Generator,
          weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()
    # Create four copies of the weights:
    # * weights_sg has stop_gradient applied to all weights, and
    # * the other three have stop_gradient applied to different partitions of
    # the weights.
    weights_sg, weights_g, weights_c, weights_e = _stop_gradient_on_some_weights(weights)

    # Inference ------------------------------------------------------------------------------------
    # This function uses weights_sg so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    code_message = encoding.code_message(observation, weights_sg)

    # GLN loss -------------------------------------------------------------------------------------
    # The scan parameters depend on weights_g.
    encoding_parameters_g = SamplerParameters(observation, prediction, weights_g)
    # This use of stop_gradient prevents the cotangents from propagating back from the scan through
    # to the observation.
    initial_code_message = stop_gradient(code_message)
    # This class manages an iterated function (a scan)
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
    # This is an extremely computationally expensive scan.
    sampler_state, sampler_trajectory = sampler.sample_trajectory(
        encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message

    # Code loss ------------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts the
    # code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    c_code_message = encoding.code_message(observation, weights_c, rng=code_rng,
                                           use_code_signal_noise=True)
    # When this loss is minimized only the weights that are not marked stop-gradient in weights_c
    # are adjusted.  Cotangents are also blocked from poisoning the scan by applying stop_gradient
    # to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss

    # Snipped a lot of code here that uses weights_e and produces output primals.

    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)

# Below is the code that uses Haiku to partition the weights and apply stop gradient to different
# partitions.
_module_classes = [{'gln'}, {'code_value', 'code_presence'}, {'explanation'}]

def _module_predicate(module_name: str,
                      name: str,
                      value: Array) -> int:
    prefix = module_name.split('/')[0]
    for i, prefix_set in enumerate(_module_classes):
        if prefix in prefix_set:
            return i
    raise RuntimeError

# I was using Haiku before, but I'll have to port this to Flax somehow.
def _partition_by_module(weights: FrozenVariableDict) -> tuple[FrozenVariableDict, ...]:
    return hk.data_structures.partition_n(_module_predicate,  # type: ignore[arg-type]
                                          weights, len(_module_classes))

def _stop_gradient_on_some_weights(weights: FrozenVariableDict) -> list[FrozenVariableDict]:
    weights_sg = stop_gradient(weights)
    weights_p = _partition_by_module(weights)
    weights_sg_p = _partition_by_module(weights_sg)

    return ([weights_sg]
            + [hk.data_structures.merge(weights_pi,
                                        *[weights_sg_pi
                                          for j, weights_sg_pi in enumerate(weights_sg_p)
                                          if i != j])
               for i, weights_pi in enumerate(weights_p)])

Non-solution:

I discussed this with @cgarciae and we brainstormed a non-solution: I could try to put the "C", "G", and "E" weights into different "collections", and then run inference three times. This doesn't work because:

  • the scan is very expensive and I don't want to run it three times,
  • the scan can't easily be hoisted out because the scan itself depends on the G collection of weights, and
  • all sorts of intermediate values are created in one part of the function and consumed in the others.

Possible Flax interface:

We came up with two Flax interfaces that might work.

I suggested some kind of context manager flax.linen.stop_gradient:

def infer(encoding: EncodingElement,
          observation: PoolingMessage,
          prediction: PredictionMessage,
          rng: Generator,
          weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()

    # Inference ------------------------------------------------------------------------------------
    # This function uses weights_sg so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    with nn.stop_gradient(lambda c: True):
        code_message = encoding.code_message(observation)

    # GLN loss -------------------------------------------------------------------------------------
    encoding_parameters_g = SamplerParameters(observation, prediction)
    # This use of stop_gradient prevents the cotangents from propagating back from the scan through
    # to the observation.
    initial_code_message = stop_gradient(code_message)
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
    # The scan parameters depend on weights_g.
    with nn.stop_gradient(lambda c: not c.name.startswith('gln')):
        # This class manages an iterated function (a scan)
        # This is an extremely computationally expensive scan.
        sampler_state, sampler_trajectory = sampler.sample_trajectory(
            encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message

    # Code loss ------------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts the
    # code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    with nn.stop_gradient(lambda c: not c.name.startswith('code')):
        c_code_message = encoding.code_message(observation, rng=code_rng,
                                               use_code_signal_noise=True)
    # When this loss is minimized only the weights that are not marked stop-gradient in weights_c
    # are adjusted.  Cotangents are also blocked from poisoning the scan by applying stop_gradient
    # to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss

    # Snipped a lot of code here that uses weights_e and produces output primals.

    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)

Cristian suggested a lifting transformation like those found in flax.core.lift. I'm still learning how these work, so I can't yet sketch what this might look like.

Possible side benefits

Besides applying stop-gradient, this kind of system may be able to do other things with parameters such as:

  • marking parameters as constant within a block, and raising if any computation tries to change them,
  • temporarily replacing parameters with values from another variable within a computation, or
  • replacing parameter cotangents with values from another variable or another parameter cotangent (see the sketch below).

Of course, that's beyond this feature request, but I mention these ideas as something to keep in mind when considering solutions.
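
As an aside on the last bullet: at the array level (outside of any Flax lifting), cotangent replacement can already be sketched in plain JAX with jax.custom_vjp. The helper below, replace_cotangent, is purely illustrative and assumes x and the replacement have the same shape:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def replace_cotangent(x, replacement):
    # Primal pass: behaves as the identity on x; replacement is unused here.
    return x

def _replace_cotangent_fwd(x, replacement):
    # Save the replacement as the residual for the backward pass.
    return x, replacement

def _replace_cotangent_bwd(replacement, x_bar):
    # Discard the incoming cotangent x_bar and hand back the stored replacement
    # as the cotangent of x; replacement itself gets a zero cotangent.
    return replacement, jnp.zeros_like(replacement)

replace_cotangent.defvjp(_replace_cotangent_fwd, _replace_cotangent_bwd)

# jax.grad(lambda x: replace_cotangent(x, 5.0))(2.0) evaluates to 5.0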

Conclusion

Am I missing an easy solution to my problem? If not, I will need to solve this problem in order to use Flax since this use of stop-gradient is integral to my research. Thanks for reading!

commented

So in the Haiku version of the code you are solving this "outside" of Haiku by operating on the variables dict directly and making 4 copies. In Flax you could do something similar, for example by using flax.traverse_util.flatten_dict. If you want to do this inside a linen Module you could use nn.map_variables where the mapping is basically the identity but with stop_gradient applied to some or all of the params (again you can make your life easy here by using flatten_dict).
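
For the "outside" variant, a minimal sketch might look like the following (stop_gradient_on_all_but is a hypothetical helper and the module-name prefixes are only illustrative; the exact key paths in your variables dict may differ):

from flax import traverse_util
from jax.lax import stop_gradient

def stop_gradient_on_all_but(variables, keep_prefixes):
    # Apply stop_gradient to every leaf whose flattened path does not live
    # under one of keep_prefixes, then rebuild the nested dict.
    def keep(path):
        return any(name.startswith(p) for name in path for p in keep_prefixes)

    flat = traverse_util.flatten_dict(variables)
    new_flat = {path: value if keep(path) else stop_gradient(value)
                for path, value in flat.items()}
    return traverse_util.unflatten_dict(new_flat)

# e.g. an analogue of weights_g: gradients only flow into modules whose names
# start with 'gln'; everything else is stopped.
# weights_g = stop_gradient_on_all_but(weights, keep_prefixes=('gln',))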

commented

Here's a sketch of what the nn.map_variables approach would look like:

import flax.linen as nn
from flax import traverse_util
from jax import lax


def selective_stop_grad(variables):
    # Apply stop_gradient to the parameters selected by some_filter_fn, keyed
    # on the flattened variable path.
    flat_vars = traverse_util.flatten_dict(variables)
    new_vars = {k: lax.stop_gradient(v) if some_filter_fn(k) else v
                for k, v in flat_vars.items()}
    return traverse_util.unflatten_dict(new_vars)


class MySGModule(nn.Module):
    @nn.compact
    def __call__(self, x):
        MySGSubModule = nn.map_variables(MySubModule, "params", selective_stop_grad, init=True)
        return MySGSubModule(...)(x)
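
Here some_filter_fn and MySubModule are placeholders. A concrete filter keyed on the flattened path could look like the following (the 'gln' prefix is just an example taken from the snippet in the issue):

def some_filter_fn(path):
    # path is a flattened key such as ('params', 'gln_0', 'kernel'); stop
    # gradients for any parameter under a module whose name starts with 'gln'.
    return any(name.startswith('gln') for name in path)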

@jheek the map_variables solution is great, I like it a lot!

A HOWTO about freezing parameters using this strategy would be great.

@jheek

I've been trying to implement your solution, but I can't seem to get it working for me. Here's roughly what I have:

from __future__ import annotations

from collections.abc import Callable
from dataclasses import asdict
from typing import Any, Generic, TypeVar

import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from flax.core.scope import FrozenVariableDict
from jax.lax import stop_gradient
from jax.random import PRNGKey

T = TypeVar('T', bound=nn.Module)


class StopGradientModule(nn.Module, Generic[T]):
    filter_f: Callable[[tuple[str, ...]], bool]
    submodule_cls: Callable[..., T]

    def setup(self) -> None:
        self.submodule = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient)

    def __call__(self, module: T) -> T:
        return self.submodule(**asdict(module))

    def _selective_stop_gradient(self, variables: FrozenVariableDict) -> dict[str, Any]:
        flat_vars = traverse_util.flatten_dict(variables)  # type: ignore[no-untyped-call]
        new_vars = {k: stop_gradient(v)
                    if self.filter_f(k) else v
                    for k, v in flat_vars.items()}
        return traverse_util.unflatten_dict(new_vars)  # type: ignore[no-untyped-call]


class X(nn.Module):
    def setup(self) -> None:
        self.dense = nn.Dense(10)
        # stop_gradient_all is a copy of self whose parameters are identical, but whose parameter
        # cotangents are always zero.
        self.stop_gradient_all = StopGradientModule(lambda _: True, X)

    def f(self, x: Any) -> Any:
        return self.dense(x), self.stop_gradient_all(self).dense(x)


print(X().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=X.f))

gives

Traceback (most recent call last):
  File "/home/neil/src/cmm/a.py", line 44, in <module>
    print(X().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=X.f))
  File "/home/neil/src/cmm/a.py", line 41, in f
    return self.dense(x), self.stop_gradient_all(self).dense(x)
  File "/home/neil/src/cmm/a.py", line 37, in setup
    self.dense = nn.Dense(10)
ValueError: Duplicate use of scope name: "dense"

I realize that this is currently a recursive mess; I'm still exploring the simplest way to accomplish what I'm after.

I'm still trying to get this working. Here's what I have now:

from __future__ import annotations

from collections.abc import Callable
from typing import Any, Generic, TypeVar

import flax.linen as nn
import jax.numpy as jnp
from flax import traverse_util
from flax.core.scope import FrozenVariableDict
from jax.lax import stop_gradient
from jax.random import PRNGKey
from tjax import print_generic

T = TypeVar('T', bound=nn.Module)


class StopGradientModule(nn.Module, Generic[T]):
    filter_f: Callable[[tuple[str, ...]], bool]
    submodule_cls: Callable[..., T]

    def setup(self) -> None:
        mapped_cls = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient,
                                      methods=['f'])
        self.submodule = mapped_cls()

    def f(self, x: Any) -> Any:
        print("Calling")
        return self.submodule.f(x)

    def _selective_stop_gradient(self, variables: FrozenVariableDict) -> dict[str, Any]:
        flat_vars = traverse_util.flatten_dict(variables)  # type: ignore[no-untyped-call]
        new_vars = {k: stop_gradient(v)
                    if self.filter_f(k) else v
                    for k, v in flat_vars.items()}
        return traverse_util.unflatten_dict(new_vars)  # type: ignore[no-untyped-call]

    def __call__(self):
        assert False


class X(nn.Module):
    def setup(self) -> None:
        self.dense = nn.Dense(3)

    def f(self, x: Any) -> Any:
        return self.dense(x)

    def __call__(self):
        assert False


class Y(nn.Module):
    def setup(self) -> None:
        self.x = X()
        # stop_gradient_all is a copy of x whose parameters are identical, but whose parameter
        # cotangents are always zero.
        self.stop_gradient_all = StopGradientModule(lambda _: True, X)

    def f(self, x: Any) -> Any:
        y = self.x.f(x)
        return y, self.stop_gradient_all.f(x)

    def __call__(self):
        assert False

(y, y_prime), variables = Y().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=Y.f)
print(y, y_prime)
print_generic(variables)

gives

Traceback (most recent call last):
  File "/home/neil/src/cmm/a.py", line 66, in <module>
    (y, y_prime), variables = Y().init_with_output({'params': PRNGKey(0)}, jnp.ones(3), method=Y.f)
  File "/home/neil/src/cmm/a.py", line 61, in f
    return y, self.stop_gradient_all.f(x)
  File "/home/neil/src/cmm/a.py", line 28, in f
    return self.submodule.f(x)
  File "/home/neil/src/cmm/a.py", line 46, in f
    return self.dense(x)
  File "/home/neil/src/flax/flax/linen/linear.py", line 177, in __call__
    kernel = self.param('kernel',
flax.errors.ScopeCollectionNotFound: Tried to access "kernel" from collection "params"" in "/stop_gradient_all/map_variables(submodule)/dense" but the collection is emtpy. (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeCollectionNotFound)

commented

Ah that's because map_variables makes collections immutable. You need to provide a function that maps back the variables on output or pass init=True such that during init the map_variables isn't called.

Can you try passing init=True to nn.map_variables? I think that should fix your issue. Actually, I accidentally removed the init=True from the example I copied from the docstring (I've updated my original example). :S
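
Concretely, applied to the setup above that just means adding init=True (the rest of StopGradientModule unchanged):

    def setup(self) -> None:
        mapped_cls = nn.map_variables(self.submodule_cls, True, self._selective_stop_gradient,
                                      init=True, methods=['f'])
        self.submodule = mapped_cls()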

@jheek Thanks, that gets it to run, but it's still not mirroring x's parameters. It outputs:

[-1.5530705 -0.6934959  0.9631546] [ 0.246286    0.83799624 -0.91129684]
FrozenDict
    params=FrozenDict
        stop_gradient_all=FrozenDict
            submodule=FrozenDict
                dense=FrozenDict
                    bias=Jax Array (3,) float32
                            0.0000      0.0000      0.0000
                    kernel=Jax Array (3, 3) float32
                            0.3932      0.3981     -0.5165
                            0.1566     -0.0768     -0.2396
                           -0.3035      0.5167     -0.1552
        x=FrozenDict
            dense=FrozenDict
                bias=Jax Array (3,) float32
                        0.0000      0.0000      0.0000
                kernel=Jax Array (3, 3) float32
                       -0.0201     -0.6220      0.9425
                       -0.2652     -0.1386      0.8165
                       -1.2677      0.0672     -0.7958

So x and stop_gradient_all are different. Any idea how I can make it a mirror? I realize I need to pass x in somehow, but I'm still not sure how.