google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io

Difference in output between jitted and non-jitted call

rodrigodzf opened this issue

I have found that the results of the forward pass differ considerably depending on whether or not the apply function is jitted.

System information

  • Ubuntu 23.10 x86_64
  • Flax, jax, jaxlib versions (obtained with pip show flax jax jaxlib): flax 0.8.2, jax 0.4.25, jaxlib 0.4.25+cuda12.cudnn89
  • Python version: v3.10.13
  • GPU/TPU model and memory: NVIDIA GeForce RTX 3090
  • CUDA version (if applicable): 12.4

Problem you have encountered:

When running the forward pass of a simple MLP, the jitted results differ from the non-jitted results.

What you expected to happen:

That the results would be the same.

Logs, error messages, etc:

Steps to reproduce:

import flax.linen as nn
import jax.numpy as jnp
import jax
from typing import Callable, Sequence

class MLP(nn.Module):
    hidden_channels: Sequence[int]
    # nn.relu is a plain function, so annotate the field as a Callable
    # rather than nn.Module.
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        for i, channels in enumerate(self.hidden_channels):
            x = nn.Dense(features=channels)(x)
            if i != len(self.hidden_channels) - 1:
                x = self.activation(x)
        return x
    
d_hidden = 64
d_input = 64
d_batch = 3
proj = MLP(
    hidden_channels=[d_hidden] * 2,
    activation=nn.selu,
)

proj_vars = proj.init(jax.random.PRNGKey(546543), jnp.ones((d_batch, d_input)))
x = jnp.ones((d_batch, d_input)) 

out = proj.apply(proj_vars, x)
jitted_out = jax.jit(proj.apply)(proj_vars, x)
diff = jnp.abs(out - jitted_out).mean()
print(f"diff={diff}")

This gives a difference of 0.0005452320910990238 on my machine. I noticed that the difference changes with the number of layers (for example, 0.00216 with 10 layers) and with the type of activation. Is this a bug, or is there something I'm missing?
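
One check I can think of (just a sketch, and it assumes the gap comes from reduced-precision float32 matmuls on the GPU rather than from jit itself; jax.default_matmul_precision is the standard JAX context manager for this, and the tolerance below is arbitrary):

# Sketch: force full float32 matmul precision (Ampere GPUs like the RTX 3090
# otherwise use TF32 for float32 matmuls) and compare the jitted and
# non-jitted outputs again. The jitted function is traced inside the context
# manager, so the precision setting applies to it as well.
with jax.default_matmul_precision("float32"):
    out_hp = proj.apply(proj_vars, x)
    jitted_out_hp = jax.jit(proj.apply)(proj_vars, x)
    print(f"diff (float32 matmul precision)={jnp.abs(out_hp - jitted_out_hp).mean()}")

# For reference, a tolerance-based comparison at default precision:
print(f"allclose (atol=1e-3)={jnp.allclose(out, jitted_out, atol=1e-3)}")

If the difference largely disappears at the higher matmul precision, that would point to expected float32/TF32 numerics rather than a Flax bug, but I'm not sure whether that's the whole story here.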