google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io

Flax much slower than pure JAX?

Binbose opened this issue

Hey, for a project I am trying to code up a very simple MLP example, but I noticed that the Flax implementation is about 20 times slower than the pure JAX implementation. What am I doing wrong here?

import time
import jax.numpy as np
from jax import random, jit, vmap, jacfwd
from jax.nn import sigmoid, softplus
import jax
from flax import linen as nn
import numpy as np  # NOTE: rebinds np to plain NumPy, shadowing the jax.numpy import above
from typing import Sequence

def MLP(layers):
    def init(rng_key):
        def init_layer(key, d_in, d_out):
            k1, k2 = random.split(key)
            W = random.normal(k1, (d_in, d_out))
            b = random.normal(k2, (d_out,))
            return W, b
        key, *keys = random.split(rng_key, len(layers))
        params = list(map(init_layer, keys, layers[:-1], layers[1:]))
        return params

    def apply(params, inputs):
        for W, b in params[:-1]:
            outputs = np.dot(inputs, W) + b
            inputs = sigmoid(outputs)
        W, b = params[-1]
        outputs = np.dot(inputs, W) + b
        return outputs
    return init, apply


class FlaxNet(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x_in):
        x = nn.Dense(self.features[0], use_bias=False)(x_in)
        x = sigmoid(x)

        for feat in self.features[1:-1]:
            x = nn.Dense(feat, use_bias=False)(x)
            x = sigmoid(x)
        x = nn.Dense(self.features[-1], use_bias=False)(x)

        return x


rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
D = np.pi

layers = [1, 64, 64, 64, 32, 4]
net_init, net_apply = MLP(layers)
params = net_init(random.PRNGKey(0))

inputs = jax.random.uniform(rng, minval=-D, maxval=D, shape=(128, 1))
_ = net_apply(params, inputs)

inputs = jax.random.uniform(rng, minval=-D, maxval=D, shape=(128, 1))
t1 = time.time()
outputs = net_apply(params, inputs)
print('TIME JAX ', time.time()-t1)

#############################################################################

model = FlaxNet(features=[64, 64, 64, 32, 4])
params = model.init(rng, inputs)

_ = model.apply(params, inputs)
t1 = time.time()
outputs = model.apply(params, inputs)
print('TIME FLAX ', time.time()-t1)

This produces the output:

TIME JAX  0.0033071041107177734
TIME FLAX  0.08791708946228027

I made a public colab that fixes your benchmark in a few ways:
https://colab.research.google.com/drive/13DYXLKmLCO3K2Pd0Z1x-hHjwaSmK7Ked?usp=sharing

Here are the comments from that colab, inlined:

  • Plain NumPy was being used by accident in the JAX NN code, because import numpy as np comes after import jax.numpy as np and shadows it. This was fixed to use jax.numpy (aka jnp).
  • Both JAX and especially Flax are designed around JIT compilation (jax.jit) to produce fast code. The colab compares the speed of jitted code against non-jitted, eagerly evaluated (interpreted) code.
  • JAX dispatches operations asynchronously, so a micro-benchmark MUST call output.block_until_ready() to wait until the results are actually ready on the device.
TIME JAX 0.005536317825317383
TIME FLAX 0.11943674087524414
TIME JAX JITTED 0.0005037784576416016
TIME FLAX JITTED 0.0024785995483398438

We see no material difference between the jitted JAX and Flax times. Flax does take considerably longer than JAX when run eagerly, but that is expected, because Flax handles a lot of general bookkeeping in the background. Flax was designed to be used with jit (or pmap, pjit, etc.).

Oh, you are right, thank you!
Maybe it makes sense to add the jit compilation line to the minimal examples on the GitHub page, so that other dummies like me don't overlook it.