Significant performance difference of NNX relative to equinox
jlperla opened this issue · comments
I decided to try the nnx vs. equinox for performance and am seeing significant differences (3'ish times slower for nnx). Could be that I wrote a poor MLP implementation or made a collosal profiling mistake.
My apologies if the benchmarking itself is flaws or the MLP implementation is incorrect in some way. But if it is the later, it shows that a documented MLP implementa`ton for NNX to copy/paste might help.
System information
- latest released NNX
- Tried on laptop (macos with CPU) as well as colab with an accelerator
Problem you have encountered:
The performance of my test suite on my CPU is
Time taken NNX: 0.00055 seconds
Time taken EQX: 0.00019 seconds
And on the colab T4 GPU runtime
Time taken NNX: 0.00220 seconds
Time taken EQX: 0.00066 seconds
Steps to reproduce:
Test Suite:
import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time
class MLP(nnx.Module):
def __init__(
self,
in_features: int,
out_features: int,
*,
width: int,
depth: int,
activation: tp.Callable,
rngs: rnglib.Rngs,
use_bias: bool = True,
use_final_bias: bool = True,
final_activation: tp.Optional[tp.Callable] = None,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
):
self.in_features = in_features
self.out_features = out_features
self.width = width
self.depth = depth
self.use_bias = use_bias
self.use_final_bias = use_final_bias
self.activation = activation
self.final_activation = final_activation
assert depth > 0 # skipping specialization of no hidden layers
self.layers = []
self.layers.append(
nnx.Linear(
in_features,
width,
use_bias=use_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
for i in range(self.depth - 1):
self.layers.append(
nnx.Linear(
width,
width,
use_bias=self.use_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
self.layers.append(self.activation)
self.layers.append(
nnx.Linear(
width,
out_features,
use_bias=self.use_final_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
if self.final_activation is not None:
self.layers.append(self.final_activation)
def __call__(self, x: jax.Array) -> jax.Array:
for layer in self.layers:
x = layer(x)
return x
if __name__ == "__main__":
rngs = nnx.Rngs(0)
@nnx.jit
def my_test(batch, model):
@nnx.jit
def loss_closure(f):
return jnp.mean(jax.vmap(f)(batch))
loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
return loss_val
n_in = 64
n_out = 1
depth = 3
width = 128
activation = nnx.relu
model = MLP(n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs)
my_batch = jax.random.normal(rngs(), (20, n_in))
# Time NNX
out = my_test(my_batch, model).block_until_ready()
start_time = time.time()
out = my_test(my_batch, model).block_until_ready()
end_time = time.time()
print(f"Time taken NNX: {end_time - start_time:.5f} seconds")
@eqx.filter_jit
def my_test_eqx(batch, model):
@eqx.filter_jit
def loss_closure(f):
return jnp.mean(jax.vmap(f)(batch))
loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
return loss_val
equinox_model = eqx.nn.MLP(n_in, n_out, width_size=width, depth=depth, activation=activation, key=rngs())
# Time Equinox
out = my_test_eqx(my_batch, equinox_model)
start_time = time.time()
out = my_test_eqx(my_batch, equinox_model).block_until_ready()
end_time = time.time()
print(f"Time taken EQX: {end_time - start_time:.5f} seconds")
On colab you need to do ! pip install equinox
To add to this: the performance of linen seems to be similar to NNX. Although I am even less clear how to profile there. Here was my implementation
import typing as tp
import jax
import jax.numpy as jnp
import flax.linen as linen
from flax.core import freeze, unfreeze
from flax.training import train_state
from flax.typing import Dtype, PrecisionLike
import optax
import time
class MLPLinen(linen.Module):
in_features: int
out_features: int
width: int
depth: int
activation: tp.Callable
use_bias: bool = True
use_final_bias: bool = True
final_activation: tp.Optional[tp.Callable] = None
dtype: tp.Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
precision: PrecisionLike = None
@linen.compact
def __call__(self, x: jax.Array) -> jax.Array:
x = linen.Dense(
self.width,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
for _ in range(self.depth - 1):
x = self.activation(x)
x = linen.Dense(
self.width,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
x = linen.Dense(
self.out_features,
use_bias=self.use_final_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
if self.final_activation is not None:
x = self.final_activation(x)
return x
def create_train_state_linen(rng, model, learning_rate):
params = model.init(rng, jnp.ones([1, model.in_features]))['params']
tx = optax.adam(learning_rate)
return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
def compute_loss_linen(params, batch, model_apply_fn):
logits = model_apply_fn({'params': params}, batch)
loss = jnp.mean(logits)
return loss
@jax.jit
def train_step_linen(state, batch):
grad_fn = jax.value_and_grad(compute_loss_linen)
loss, grads = grad_fn(state.params, batch, state.apply_fn)
state = state.apply_gradients(grads=grads)
return state, loss
if __name__ == "__main__":
rng = jax.random.PRNGKey(0)
n_in = 64
n_out = 1
depth = 3
width = 128
activation = linen.relu
model = MLPLinen(n_in, n_out, width=width, depth=depth, activation=activation)
state = create_train_state_linen(rng, model, learning_rate=0.001)
my_batch = jax.random.normal(rng, (20, n_in))
# Time Linen
state, loss_val = train_step_linen(state, my_batch)
jax.block_until_ready(loss_val)
start_time = time.time()
state, loss_val = train_step_linen(state, my_batch)
jax.block_until_ready(loss_val)
end_time = time.time()
print(f"Time taken Linen: {end_time - start_time:.5f} seconds")
Hey @jlperla, can you use timeit
or similar to report the results? A single step, specially the first one that involves compilation is not very meaningful.
That said, this is what I would expect:
- Linen should be the fastest one as the
params
structure is a simple dictionary and you are using regularjax.jit
. JAX has optimized code to traverse dicts. - Both NNX and Equinox suffer due to python being slow, see Low-overhead training loops.
@jlperla Maybe useful to note here, For small MLPs you are likely will be in the overhead regime. To overcome the framework overhead (in nnx or equinox) you may use nnx.{split,merge}
or equinox.{parition,combine}
pattern with non lifted jax transforms.
@ASEM000 correct. Ideally we document how to overcome the overhead problem in the near future.
@cgarciae @ASEM000 Absolutely. But the issue is comparing the relative overhead of NNX vs. Equinox for the same pattern? I find the timeit hard to use, but made sure things were compiled and retried multiple times?
Why the equinox code would be so much faster than NNX (which seems roughly similar to flax linen)? What is the overhead that would be so much more significant there, using the same coding pattern? If you look at my code I am isolating a single "value and grad" call, no optimizer overhead or training loop. And precompiling it before timing.
So either
0) It looks like my two sets of code are doing the same thing, but they really aren't.
- I implemented the MLP poorly in NNX, which is VERY likely, and the one in equinox is done correctly.
- There is some sort of overhead in the filtering process which is significantly more expensive in NNX vs. equinox. Maybe a manual split and combine (which can be done in both) would make it disappear
@jlperla I do imagine the NNX overhead being greater than the Equinox overhead as we do more bookkeeping and its not optimized. If performance is critical you should just train using split
/ merge
. Here is a modified version comparing both NNX and Equinox using low-overhead versions:
from functools import partial
import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time
class MLP(nnx.Module):
def __init__(
self,
in_features: int,
out_features: int,
*,
width: int,
depth: int,
activation: tp.Callable,
rngs: rnglib.Rngs,
use_bias: bool = True,
use_final_bias: bool = True,
final_activation: tp.Optional[tp.Callable] = None,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
):
self.in_features = in_features
self.out_features = out_features
self.width = width
self.depth = depth
self.use_bias = use_bias
self.use_final_bias = use_final_bias
self.activation = activation
self.final_activation = final_activation
assert depth > 0 # skipping specialization of no hidden layers
self.layers = []
self.layers.append(
nnx.Linear(
in_features,
width,
use_bias=use_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
for i in range(self.depth - 1):
self.layers.append(
nnx.Linear(
width,
width,
use_bias=self.use_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
self.layers.append(self.activation)
self.layers.append(
nnx.Linear(
width,
out_features,
use_bias=self.use_final_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
if self.final_activation is not None:
self.layers.append(self.final_activation)
def __call__(self, x: jax.Array) -> jax.Array:
for layer in self.layers:
x = layer(x)
return x
if __name__ == '__main__':
rngs = nnx.Rngs(0)
@jax.jit
def my_test(batch, graphdef, state):
model = nnx.merge(graphdef, state)
def loss_closure(model):
return jnp.mean(jax.vmap(model)(batch))
loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
return loss_val
n_in = 64
n_out = 1
depth = 3
width = 128
activation = nnx.relu
model = MLP(
n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs
)
my_batch = jax.random.normal(rngs(), (20, n_in))
graphdef, state = nnx.split(model)
# Time NNX
out = my_test(my_batch, graphdef, state).block_until_ready()
start_time = time.time()
out = my_test(my_batch, graphdef, state).block_until_ready()
end_time = time.time()
print(f'Time taken NNX: {end_time - start_time:.5f} seconds')
# -----------
# Equinox
# -----------
@eqx.filter_jit
def my_test_eqx(batch, treedef, leaves):
model = jax.tree.unflatten(treedef, leaves)
@eqx.filter_jit
def loss_closure(f):
return jnp.mean(jax.vmap(f)(batch))
loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
return loss_val
equinox_model = eqx.nn.MLP(
n_in,
n_out,
width_size=width,
depth=depth,
activation=activation,
key=rngs(),
)
leaves, treedef = jax.tree.flatten(equinox_model)
# Time Equinox
out = my_test_eqx(my_batch, treedef, leaves)
start_time = time.time()
out = my_test_eqx(my_batch, treedef, leaves).block_until_ready()
end_time = time.time()
print(f'Time taken EQX: {end_time - start_time:.5f} seconds')
Output on my M1:
Time taken NNX: 0.00007 seconds
Time taken EQX: 0.00019 seconds
This version might still be suboptimal for Equinox because of the use of eqx.filter_jit
instead of jax.jit
.
We will add a guide on NNX transforms explaining how they work under the hood in the future.
Some documentation would be very useful, also ran into this when profiling nnx vs. linen.
Linen is already low-overhead, I'll try to add it to the benchmark.
@cgarciaethanks, this helps a lot. I don't feel like you need to compare to equinox in your docs. My main concern was that it seemed to be 3x slower for the same task. But if you are doing more bookkeeping, then it isn't really the same task.
and just to confirm: my MLP implementation is as high performance as possible? If so, maybe that is helpful to have in the docs for people to adapt.
I believe so. @jlperla do you want to contribute it as an NNX example?
While we don't let people directly import examples we can point to it on the documentation and it could serve as a reference implementation that people can easily copy into their codebase.
@cgarciae OK, here is my attempt comparing your code. The summary is now:
- equinox partition-combine vs. NNX split-merge has roughly the same performance
- NNX filtering is roughly the same speed as linen (but that is a little unfair to linen, if you look at my code, since I couldn't decouple the gradient calculation from setting up an optimizer).
- equinox filtering without manual splitting is roughly 2-3 times faster than NNX splitting.
- I didn't do the equivalent to https://docs.kidger.site/equinox/tricks/#low-overhead-training-loops which does a full flatten to really speed things up, which might make sense in library code even if it is too ugly for most user code. You might want to consider an example in your docs which does the equivalent full flattening, but I think lower priority.
At this point I am now convinced that there is nothing fundamentally different between NNX and Equinox that holds back performance, even if there is probably some performance tweaks that may occur in the future. I feel like you could close out this issue and I could prepare a simple example for the docs (without the performance comparisons) if you are willing? Maybe a simple nonlinear regression with an example MLP?
If you are interested, here was my code, which runs on my system as
Time taken NNX: 0.000358 seconds
Time taken NNX Split: 0.000054 seconds
Time taken EQX: 0.000130 seconds
Time taken EQX Split: 0.000051 seconds
Time taken Linen: 0.000394 seconds
import typing as tp
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx.nnx import rnglib
from flax.typing import Dtype, PrecisionLike
import equinox as eqx
import time
import flax.linen as linen
from flax.core import freeze, unfreeze
from flax.training import train_state
from functools import partial
import optax
class MLPLinen(linen.Module):
in_features: int
out_features: int
width: int
depth: int
activation: tp.Callable
use_bias: bool = True
use_final_bias: bool = True
final_activation: tp.Optional[tp.Callable] = None
dtype: tp.Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
precision: PrecisionLike = None
@linen.compact
def __call__(self, x: jax.Array) -> jax.Array:
x = linen.Dense(
self.width,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
for _ in range(self.depth - 1):
x = self.activation(x)
x = linen.Dense(
self.width,
use_bias=self.use_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
x = linen.Dense(
self.out_features,
use_bias=self.use_final_bias,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
)(x)
if self.final_activation is not None:
x = self.final_activation(x)
return x
class MLP(nnx.Module):
def __init__(
self,
in_features: int,
out_features: int,
*,
width: int,
depth: int,
activation: tp.Callable,
rngs: rnglib.Rngs,
use_bias: bool = True,
use_final_bias: bool = True,
final_activation: tp.Optional[tp.Callable] = None,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
precision: PrecisionLike = None,
):
self.in_features = in_features
self.out_features = out_features
self.width = width
self.depth = depth
self.use_bias = use_bias
self.use_final_bias = use_final_bias
self.activation = activation
self.final_activation = final_activation
assert depth > 0 # skipping specialization of no hidden layers
self.layers = []
self.layers.append(
nnx.Linear(
in_features,
width,
use_bias=use_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
for i in range(self.depth - 1):
self.layers.append(
nnx.Linear(
width,
width,
use_bias=self.use_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
self.layers.append(self.activation)
self.layers.append(
nnx.Linear(
width,
out_features,
use_bias=self.use_final_bias,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
)
if self.final_activation is not None:
self.layers.append(self.final_activation)
def __call__(self, x: jax.Array) -> jax.Array:
for layer in self.layers:
x = layer(x)
return x
if __name__ == "__main__":
rngs = nnx.Rngs(0)
@nnx.jit
def my_test(batch, model):
def loss_closure(f):
return jnp.mean(jax.vmap(f)(batch))
loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
return loss_val
@jax.jit
def my_test_split(batch, graphdef, state):
model = nnx.merge(graphdef, state)
def loss_closure(model):
return jnp.mean(jax.vmap(model)(batch))
loss_val, loss_grad = nnx.value_and_grad(loss_closure)(model)
return loss_val
n_in = 64
n_out = 1
depth = 1
width = 128
activation = nnx.relu
model = MLP(n_in, n_out, width=width, depth=depth, activation=activation, rngs=rngs)
my_batch = jax.random.normal(rngs(), (20, n_in))
# Time NNX
out = my_test(my_batch, model).block_until_ready()
start_time = time.time()
out = my_test(my_batch, model).block_until_ready()
end_time = time.time()
print(f"Time taken NNX: {end_time - start_time:.6f} seconds")
graphdef, state = nnx.split(model)
out = my_test_split(my_batch, graphdef, state).block_until_ready()
start_time = time.time()
out = my_test_split(my_batch, graphdef, state).block_until_ready()
end_time = time.time()
print(f"Time taken NNX Split: {end_time - start_time:.6f} seconds")
@eqx.filter_jit
def my_test_eqx(batch, model):
@eqx.filter_jit
def loss_closure(f):
return jnp.mean(jax.vmap(f)(batch))
loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
return loss_val
@partial(jax.jit, static_argnums=2)
def my_test_eqx_split(batch, params, static):
model = eqx.combine(params, static)
def loss_closure(f):
return jnp.mean(jax.vmap(f)(batch))
loss_val, loss_grad = eqx.filter_value_and_grad(loss_closure)(model)
return loss_val
equinox_model = eqx.nn.MLP(n_in, n_out, width_size=width, depth=depth, activation=activation, key=rngs())
# Time Equinox
out = my_test_eqx(my_batch, equinox_model)
start_time = time.time()
out = my_test_eqx(my_batch, equinox_model).block_until_ready()
end_time = time.time()
print(f"Time taken EQX: {end_time - start_time:.6f} seconds")
params, static = eqx.partition(equinox_model, eqx.is_array)
out = my_test_eqx_split(my_batch, params, static)
start_time = time.time()
out = my_test_eqx_split(my_batch, params, static).block_until_ready()
end_time = time.time()
print(f"Time taken EQX Split: {end_time - start_time:.6f} seconds")
# Time Linen
def create_train_state_linen(rng, model, learning_rate):
params = model.init(rng, jnp.ones([1, model.in_features]))['params']
tx = optax.adam(learning_rate)
return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
def compute_loss_linen(params, batch, model_apply_fn):
logits = model_apply_fn({'params': params}, batch)
loss = jnp.mean(logits)
return loss
@jax.jit
def train_step_linen(state, batch):
grad_fn = jax.value_and_grad(compute_loss_linen)
loss, grads = grad_fn(state.params, batch, state.apply_fn)
state = state.apply_gradients(grads=grads)
return state, loss
model = MLPLinen(n_in, n_out, width=width, depth=depth, activation=linen.relu)
state = create_train_state_linen(rngs(), model, learning_rate=0.001)
my_batch = jax.random.normal(rngs(), (20, n_in))
# Time Linen
state, loss_val = train_step_linen(state, my_batch)
jax.block_until_ready(loss_val)
start_time = time.time()
state, loss_val = train_step_linen(state, my_batch)
jax.block_until_ready(loss_val)
end_time = time.time()
print(f"Time taken Linen: {end_time - start_time:.6f} seconds")
Sorry to bug you @cgarciae but one more quetsion on this. In the standard training loop, as in https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/mnist_tutorial.html#training-step the optimizer updates the model implicitly with optimizer.update(grads)
If I need to repeatedly split the internal model before calling the optimized step, what is the best pattern to do so? Do I adapt that code to do something like
@jax.jit
def train_step(state, graphdef, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
"""Train for a single step."""
model = nnx.merge(graphdef, state)
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label'])
optimizer.update(grads)
Then update my training loop to something like
for step, batch in enumerate(train_ds.as_numpy_iterator()):
state, graphdef = nnx.split(optimizer.model)
train_step(state, graph_def, optimizer, metrics, batch)
Is that what you had in mind? If so, maybe I could get someone to put in a PR to extend that MNIST tutorial along those lines as proactice?
To be honest, I don't quite understand why any of this speeds things up so much. I would have thought flattening things out and splitting at the boundary of the loss_fun
would have been where things really made a difference... but JAX performance intuition is not for mortals.
Hey @jlperla ! The best way to do this is split
all the objects together in a tuple and unpack them after merge
, then do another split
at the end.
graphdef, state = nnx.split((model, optimizer, metrics))
...
@nnx.jit
def train_step(graphdef, state, batch):
"""Train for a single step."""
model, optimizer, metrics = nnx.merge(graphdef, state)
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(model, batch)
metrics.update(loss=loss, logits=logits, labels=batch['label'])
optimizer.update(grads)
_, state = nnx.split((model, optimizer, metrics))
return loss, state
By splitting them together their mutual references are preserved.
If so, maybe I could get someone to put in a PR to extend that MNIST tutorial along those lines as proactice?
I think we should have a "performance" guide, or at least add it to a section in some other guide but its indeed an important topic.
Just to point out here that during nnx models training, the overall GPU usage do not cross above 20%. Will splitting into graphdef and state improve the performance?