How to extract intermediate values from "underneath" `vmap`?
jatentaki opened this issue
I have a model (a simplified version is below) which uses `vmap` to apply MHA Siamese-style (with weight sharing), in order to run a transformer decoder on two sequences jointly. I would now like to extract the attention weights to visualize what the model is doing, but it seems that `vmap` "silences" everything underneath it and makes things like `linen.Module.sow` fail; not even the `capture_intermediates` approach seems to work, despite being advertised as a sledgehammer. I would like to ask for the best/easiest way to accomplish that goal.
I believe a way forward would be to fork `linen.MultiHeadDotProductAttention` and make it return attention weights, then chain-return them out of my `CrossAttention`, adjust the `vmap` `out_axes` accordingly, and use `sow` inside of `ParallelDecoder.__call__` (the first location outside of `vmap`'s scope). Is there any less brutal way to achieve that goal?
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn


class CrossAttention(nn.Module):
    nhead: int
    dropout: float = 0.1
    norm_fn: nn.Module = nn.LayerNorm

    @partial(
        nn.vmap,
        in_axes=(0, 0, None),
        out_axes=0,
        variable_axes={'params': None},
        split_rngs={'params': False, 'dropout': True},
    )
    @nn.compact
    def __call__(
        self,
        queries: 'Q C',
        memory: 'M C',
        deterministic: bool,
    ) -> 'Q C':
        return nn.MultiHeadDotProductAttention(
            self.nhead,
            dropout_rate=self.dropout,
        )(
            self.norm_fn(name='mha-norm')(queries),
            memory,
            deterministic=deterministic,
        )


class ParallelDecoder(nn.Module):
    n_query: int = 8
    depth: int = 4

    @nn.compact
    def __call__(self, memory: '2 M C') -> '2 Q C':
        _2, M, C = memory.shape
        queries = self.param('q', nn.initializers.normal(stddev=1.), (2, self.n_query, C))
        for _ in range(self.depth):
            queries = CrossAttention(nhead=8)(queries, memory, True)
            # in a more complete example I would apply an MLP on queries
            # to enable mixing between the two items in the leading dimension
        return queries

## Example usage
model = ParallelDecoder()
input = np.random.randn(2, 16, 128)
state = model.init(jax.random.PRNGKey(42), input)
output, interms = model.apply(state, input, capture_intermediates=True, mutable=['intermediates'])
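
To make the workaround described above more concrete, here is a minimal sketch of its core idea: the attention function returns its weights as a second output, and `out_axes` is widened to cover both outputs. It is only an illustration under simplifying assumptions. It uses plain `jax.vmap` rather than `nn.vmap`, a single attention head, and hand-made weight matrices; the names `attend_with_weights`, `siamese_attend`, and the `wq`/`wk`/`wv` keys are hypothetical and not part of Flax.

```python
import jax
import jax.numpy as jnp


def attend_with_weights(queries, memory, params):
    """Single-head cross-attention for ONE item: queries (Q, C), memory (M, C).

    Hypothetical stand-in for a forked MultiHeadDotProductAttention that
    returns the attention weights alongside the attended output.
    """
    q = queries @ params['wq']                  # (Q, C)
    k = memory @ params['wk']                   # (M, C)
    v = memory @ params['wv']                   # (M, C)
    logits = q @ k.T / jnp.sqrt(q.shape[-1])    # (Q, M)
    weights = jax.nn.softmax(logits, axis=-1)   # each row sums to 1
    return weights @ v, weights                 # (Q, C), (Q, M)


# params are broadcast (in_axes=None), which plays the role of
# variable_axes={'params': None}; both outputs are mapped over axis 0.
siamese_attend = jax.vmap(
    attend_with_weights, in_axes=(0, 0, None), out_axes=(0, 0)
)

C, Q, M = 16, 8, 10  # illustrative sizes, not the ones from the model above
k_q, k_k, k_v, k_x, k_m = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    'wq': jax.random.normal(k_q, (C, C)) / jnp.sqrt(C),
    'wk': jax.random.normal(k_k, (C, C)) / jnp.sqrt(C),
    'wv': jax.random.normal(k_v, (C, C)) / jnp.sqrt(C),
}
queries = jax.random.normal(k_x, (2, Q, C))
memory = jax.random.normal(k_m, (2, M, C))

out, attn = siamese_attend(queries, memory, params)
print(out.shape, attn.shape)  # (2, 8, 16) (2, 8, 10)
```

In the Flax model, the second output would be returned from `CrossAttention.__call__` in the same way and then recorded with `self.sow` in `ParallelDecoder.__call__`, which sits outside the `vmap`.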