Difference in output between jitted and non-jitted call
rodrigodzf opened this issue
Rodrigo Diaz commented
I have found that the results of the forward pass differ considerably depending on whether or not the `apply` function is jitted.
System information
- OS: Ubuntu 23.10 x86_64
- Flax, jax, jaxlib versions (obtained with `pip show flax jax jaxlib`): flax 0.8.2, jax 0.4.25, jaxlib 0.4.25+cuda12.cudnn89
- Python version: 3.10.13
- GPU/TPU model and memory: NVIDIA GeForce RTX 3090
- CUDA version (if applicable): 12.4
Problem you have encountered:
When running the forward pass of a simple MLP, the jitted version produces results that differ from the non-jitted one.
What you expected to happen:
That the results would be the same.
Logs, error messages, etc.:
Steps to reproduce:
```python
from typing import Callable, Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp


class MLP(nn.Module):
    """Simple MLP; the activation is applied after every layer except the last."""

    hidden_channels: Sequence[int]
    activation: Callable = nn.relu  # activations are plain callables, not Modules

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        for i, channels in enumerate(self.hidden_channels):
            x = nn.Dense(features=channels)(x)
            if i != len(self.hidden_channels) - 1:
                x = self.activation(x)
        return x


d_hidden = 64
d_input = 64
d_batch = 3

proj = MLP(
    hidden_channels=[d_hidden] * 2,
    activation=nn.selu,
)
proj_vars = proj.init(jax.random.PRNGKey(546543), jnp.ones((d_batch, d_input)))

x = jnp.ones((d_batch, d_input))

# Same parameters and input, with and without jit.
out = proj.apply(proj_vars, x)
jitted_out = jax.jit(proj.apply)(proj_vars, x)

diff = jnp.abs(out - jitted_out).mean()
print(f"diff={diff}")
```
This gives a mean absolute difference of 0.0005452320910990238 on my machine. I noticed that the difference changes with the number of layers (for example, 0.00216 with 10 layers) and with the type of activation. Is this a bug, or is there something I'm missing?
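
For reference, here is a minimal diagnostic sketch (not part of the original report) of how one might check whether a gap like this is ordinary float32 rounding noise, e.g. from XLA reordering or fusing operations, rather than a genuine bug. It assumes the `proj`, `proj_vars`, and `x` objects from the reproduction above; `jax.default_matmul_precision` is a real JAX context manager, but whether it changes anything here depends on the backend.

```python
# Hedged diagnostic sketch; reuses `proj`, `proj_vars`, and `x` from above.
import jax
import jax.numpy as jnp

out = proj.apply(proj_vars, x)
jitted_out = jax.jit(proj.apply)(proj_vars, x)

# Put the gap in context of the output scale: a small max relative error
# is plausible float32 rounding noise after operation reordering/fusion.
rel_err = jnp.abs(out - jitted_out) / (jnp.abs(out) + 1e-8)
print(f"max relative error: {float(rel_err.max()):e}")

# Some GPU matmul paths trade precision for speed by default; forcing full
# float32 matmul precision shows whether that accounts for the difference.
with jax.default_matmul_precision("float32"):
    out_hp = proj.apply(proj_vars, x)
    jitted_out_hp = jax.jit(proj.apply)(proj_vars, x)
print(f"diff at float32 matmul precision: "
      f"{float(jnp.abs(out_hp - jitted_out_hp).mean()):e}")
```

If the difference shrinks under forced float32 matmul precision (or disappears when run on CPU), that would point to precision/fusion behavior rather than a correctness bug in jit itself.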