Using `vmap` on the loss function changes the results
Hello, I am not sure whether this should be filed as a JAX or a Flax issue.
Here is a simplified description of the training steps in my context (see the example below for more detail):
- I apply the model to the whole input data
- I select some indices of the obtained predictions
- The selected predictions are used, along with the corresponding labels, to compute the loss

The reason I do this is that I am working with a graph neural network (for node regression). I get a prediction for every node, but I split the nodes into batches and average the loss/gradient across batches.
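In code, the per-batch loss looks roughly like this (a simplified, self-contained sketch with a plain linear layer and a squared-error loss standing in for my actual GNN; all names here are placeholders):

```python
# Simplified sketch of the pattern described above (placeholder model and loss,
# not the actual GNN): run the model on all nodes once, then compute the loss
# only on the nodes selected by a batch of indices.
import jax.numpy as jnp

def batch_loss(w, node_features, node_labels, batch_idx):
    preds = node_features @ w  # predictions for every node
    return jnp.mean((preds[batch_idx] - node_labels[batch_idx]) ** 2)  # loss on the batch only
```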
My loss function takes as input the indices to be used for the computation. Instead of calling the function once per batch, I wanted to use `vmap` on the loss function, which I expected to be faster. However, it seems to lead to different results. Here is an example of the issue:
```python
import flax.linen as nn
import jax
import jax.numpy as N
import jax.random as R
import jax.tree_util as T
import optax

key = R.PRNGKey(0)
rng1, rng2, rng3, rng4 = R.split(key, 4)

X = R.normal(rng1, (1000, 20))  # Input: <num_instances> x <num_features>
Y = R.normal(rng2, (1000, 1))   # Label: <num_instances> x 1
# Batches of indices: <num_batches> x <len_batches>
batches = R.randint(rng3, (10, 500), 0, 1000)

model = nn.Dense(1)
model_vars = model.init(rng4, N.ones((1, 20)))

# Method 1: with vmap on the loss function
@jax.jit
def train_step1(variables, x, y, bs):
    def loss_fn(v, i):
        p = model.apply(v, x)
        loss = N.mean(optax.sigmoid_binary_cross_entropy(p[i], y[i]))
        return loss

    v_loss_grad_fn = jax.vmap(jax.value_and_grad(loss_fn), in_axes=(None, 0), out_axes=0)
    loss, grads = v_loss_grad_fn(variables, bs)
    loss = N.mean(loss, axis=0)
    grads = T.tree_map(lambda g: N.mean(g, axis=0), grads)
    variables = T.tree_map(lambda v, g: v - 0.01 * g, variables, grads)
    return variables, loss

# Method 2: without vmap on the loss function
@jax.jit
def train_step2(variables, x, y, bs):
    def loss_fn(v, i):
        p = model.apply(v, x)
        loss = N.mean(optax.sigmoid_binary_cross_entropy(p[i], y[i]))
        return loss

    loss_grad_fn = jax.value_and_grad(loss_fn)
    l_list = []
    g_list = []
    for i in range(10):  # Number of batches
        idx = bs[i]
        loss, grads = loss_grad_fn(variables, idx)
        l_list.append(loss)
        g_list.append(grads)
    loss = N.mean(N.stack(l_list, axis=0), axis=0)
    grads = T.tree_map(lambda *g: N.mean(N.stack(g, axis=0), axis=0), *g_list)
    variables = T.tree_map(lambda v, g: v - 0.01 * g, variables, grads)
    return variables, loss

# Warming up
train_step1(model_vars, N.ones((1, 20)), N.ones((1, 1)), N.asarray([[0]]))
train_step2(model_vars, N.ones((1, 20)), N.ones((1, 1)), N.asarray([[0]]))

v1 = model_vars
v2 = model_vars

# Comparing results before training
p1 = model.apply(v1, X)
p2 = model.apply(v2, X)
print("Same results before trainings:", N.all(p1 == p2))

# Training
for _ in range(1000):
    v1, l1 = train_step1(v1, X, Y, batches)
    v2, l2 = train_step2(v2, X, Y, batches)

# Comparing results after both trainings
p1 = model.apply(v1, X)
p2 = model.apply(v2, X)
print("Same results after trainings:", N.all(p1 == p2))
print("Same losses after trainings:", l1 == l2)
```
System information
- OS Platform and Distribution: Windows 11
- Flax, jax, jaxlib versions: flax (0.8.0) / jax (0.4.28) / jaxlib (0.4.28)
- Python version: 3.11.5
Maybe I misunderstood how `vmap` works, but I think both methods described above should behave identically, so I don't understand why they lead to different results. Could the difference be due to some approximation during the computation (the losses are still identical at the last iteration)?
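In case it helps, a follow-up check I could add (reusing `v1`, `v2`, `N`, and `T` from the script above) is to compare the trained parameters with a tolerance rather than exact equality, to see whether the mismatch is just floating-point noise:

```python
# Illustrative follow-up check, reusing v1/v2, N (jax.numpy) and T (jax.tree_util) from above:
# compare the trained parameters with a tolerance instead of exact equality.
max_diff = T.tree_map(lambda a, b: N.max(N.abs(a - b)), v1, v2)
close = T.tree_map(lambda a, b: bool(N.allclose(a, b, atol=1e-5)), v1, v2)
print("Max absolute parameter difference per leaf:", max_diff)
print("Parameters equal within 1e-5:", close)
```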
Thank you for your help.