google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

How to extract intermediate values from "underneath" `vmap`?

jatentaki opened this issue

I have a model (a simplified version is below) that uses `vmap` to apply multi-head attention (MHA) Siamese-style, with weight sharing, so that a transformer decoder runs on two sequences jointly. I would now like to extract the attention weights to visualize what the model is doing, but `vmap` seems to "silence" everything underneath it: it makes calls such as `linen.Module.sow` fail, and not even the `capture_intermediates` approach seems to work, despite being advertised as a sledgehammer. What is the best (or easiest) way to accomplish this?

I believe one way forward would be to fork `linen.MultiHeadDotProductAttention` so that it returns the attention weights, chain-return them out of my `CrossAttention`, adjust the `vmap` `out_axes` accordingly, and call `sow` inside `ParallelDecoder.__call__` (the first location outside `vmap`'s scope). Is there a less brutal way to achieve this?

```python
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn

class CrossAttention(nn.Module):
    nhead: int
    dropout: float = 0.1
    norm_fn: Callable[..., nn.Module] = nn.LayerNorm

    # Map over the leading axis of `queries` and `memory` (the two sequences)
    # while sharing one set of parameters between them (Siamese-style).
    @partial(
        nn.vmap,
        in_axes=(0, 0, None),
        out_axes=0,
        variable_axes={'params': None},
        split_rngs={'params': False, 'dropout': True},
    )
    @nn.compact
    def __call__(
        self,
        queries: 'Q C',
        memory: 'M C',
        deterministic: bool,
    ) -> 'Q C':
        return nn.MultiHeadDotProductAttention(
            self.nhead,
            dropout_rate=self.dropout,
        )(
            self.norm_fn(name='mha-norm')(queries),
            memory,
            deterministic=deterministic,
        )

class ParallelDecoder(nn.Module):
    n_query: int = 8
    depth: int = 4

    @nn.compact
    def __call__(self, memory: '2 M C') -> '2 Q C':
        _2, M, C = memory.shape

        # Learned queries, one set per sequence: shape (2, n_query, C).
        queries = self.param('q', nn.initializers.normal(stddev=1.), (2, self.n_query, C))

        for _ in range(self.depth):
            queries = CrossAttention(nhead=8)(queries, memory, True)

            # in a more complete example I would apply an MLP on queries
            # to enable mixing between the two items in the leading dimension

        return queries

# Example usage
model = ParallelDecoder()
inputs = np.random.randn(2, 16, 128)
variables = model.init(jax.random.PRNGKey(42), inputs)
output, interms = model.apply(variables, inputs, capture_intermediates=True, mutable=['intermediates'])
```