google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home page: http://jax.readthedocs.io/


vmap is much slower than manual vectorization

bbfrog opened this issue

Please see the simple example below. I tested its performance on a V100 GPU: manual vectorization is 3x faster than using vmap for automatic vectorization. Is this expected? Thanks!

import ast
import jax.numpy as jnp
from jax import jit, random, vmap
import time

def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='Jax batching performance test')
    parser.add_argument("--vmap", default=False, type=ast.literal_eval, help="Whether using vmap for vectorization")
    parser.add_argument("--jit", default=False, type=ast.literal_eval, help="Whether do jit for function")
    parser.add_argument('--T', default=1000, type=int, help='Run times')
    args = parser.parse_args()
    return args

args = parse_args()

def predict(w, x):
    #print(x.ndim)
    outputs = jnp.matmul(x, w)
    return outputs

func = jit(predict) if args.jit else predict

key = random.PRNGKey(0)
x = random.normal(key, (256,512))
w = random.normal(key, (512,512))

start = time.time()
for _ in range(args.T):
    if args.vmap:
        vmap(func, in_axes=(None, 0))(w, x)
    else:
        func(w, x)
stop = time.time()

print("avg time: %f us" % ((stop - start) * 1e6 / args.T))

In general, vmap should not be slower than manual vectorization.

But in this case I'm not sure we're measuring what we want. JAX dispatches work asynchronously, so we need to call block_until_ready(); otherwise we're only timing the dispatch, not any actual work.

What are the times that you measure?
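A minimal sketch of a timing harness along those lines, reusing the issue's `predict`, `w` and `x`. The helper name `time_fn`, the warm-up call, the split PRNG keys and building each transformed function once outside the loop are my own choices for illustration, not code from this thread:

```python
import time

import jax
import jax.numpy as jnp
from jax import random


def predict(w, x):
    return jnp.matmul(x, w)


def time_fn(fn, *args, n_iter=100):
    """Hypothetical helper: mean wall-clock seconds per call, after a warm-up."""
    fn(*args).block_until_ready()  # warm-up call: tracing and compilation happen here
    start = time.perf_counter()
    for _ in range(n_iter):
        # Without block_until_ready, this loop would mostly measure asynchronous
        # dispatch rather than the time to actually compute the result.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_iter


key_w, key_x = random.split(random.PRNGKey(0))
x = random.normal(key_x, (256, 512))
w = random.normal(key_w, (512, 512))

# Build each transformed function once, instead of re-wrapping it on every iteration.
manual = jax.jit(predict)
auto = jax.jit(jax.vmap(predict, in_axes=(None, 0)))  # map over rows of x

print(f"manual: {time_fn(manual, w, x) * 1e6:.1f} us")
print(f"vmap:   {time_fn(auto, w, x) * 1e6:.1f} us")
```

Excluding the first call keeps compilation time out of the average, which otherwise dominates short benchmarks.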

Thanks mattjj. I added block_until_ready as shown below (sorry, I've only just started learning JAX). Am I using it correctly? After adding it, the timings are almost the same, and vmap is still much slower than manual vectorization. Could you please check?

start = time.time()
for _ in range(args.T):
    if args.vmap:
        vmap(func, in_axes=(None, 0))(w, x).block_until_ready()
    else:
        func(w, x).block_until_ready()
stop = time.time()

I see a relative slowdown on CPU too, when there's no jit or when nesting vmap-of-jit:

$ cat test.py
from functools import partial
from timeit import timeit
from jax import vmap, jit, random, numpy as jnp

n, d = 512, 64
a = random.normal(random.PRNGKey(0), (n, d))
b = random.normal(random.PRNGKey(0), (d, d))

mm = jnp.matmul
v = partial(vmap, in_axes=(0, None))

for f in [mm, v(mm), jit(mm), v(jit(mm)), jit(v(mm))]:
  run = lambda: f(a, b).block_until_ready()
  t = timeit(run, setup=run, number=1000)
  print(f'{t:.3f}')
$ python test.py
0.308
0.545
0.150
0.344
0.149

As a workaround, putting jit outside vmap (rather than the other way around) seems to recover the expected performance.
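Before reading the timings as a pure overhead comparison, it can help to confirm that all five compositions in the test above compute the same product. A minimal sketch, assuming the same `n`, `d`, `mm` and `v` as above; the `variants` dict and the split PRNG keys are illustrative additions:

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, random, vmap

n, d = 512, 64
key_a, key_b = random.split(random.PRNGKey(0))
a = random.normal(key_a, (n, d))
b = random.normal(key_b, (d, d))

mm = jnp.matmul
v = partial(vmap, in_axes=(0, None))  # map over rows of `a`; `b` is shared

variants = {
    "mm": mm,
    "v(mm)": v(mm),
    "jit(mm)": jit(mm),
    "v(jit(mm))": v(jit(mm)),
    "jit(v(mm))": jit(v(mm)),
}

reference = a @ b  # the manually vectorized product
results = {name: f(a, b) for name, f in variants.items()}
for name, out in results.items():
    # Every composition should give the same (n, d) result, so timing
    # differences between them come from how they are dispatched, not the math.
    print(f"{name:>11}: max |diff| = {float(jnp.max(jnp.abs(out - reference))):.2e}")
```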

Thanks froystig for the refined test :). I ran your code on a V100 GPU and got the numbers below; the workaround still works.
0.275
0.913
0.070
0.530
0.069

Closing this issue, since jit(vmap(...)) matches the performance of manual vectorization.

In general, putting jit on the outside is always a good idea!

Still, I'm curious where the time was being spent here: was it just higher overheads from having vmap on the outside (plausible, expected), or something else (would be weird...)? Maybe we'll look at a profile.

Thanks very much, mattjj. It would be great if you could share the profile results and/or your conclusions. Please feel free to reopen this issue to track it. Thanks!

Yeah, it's basically overheads.

Here's a vmap-of-jit profile on GPU (using the unreleased jaxlib==0.1.65, on a Colab instance, so overheads are probably higher than usual):

[Profile screenshot: vmap-of-jit on GPU]

Here's jit-of-vmap:

[Profile screenshot: jit-of-vmap on GPU]

With jit-of-vmap, we hit a very fast path: the call to the jitted function jumps straight into C++, which immediately calls into the JAX runtime (called PjRt) to enqueue a Volta sgemm kernel.

The vmap-of-jit path has a lot more going on. First we do Python work in the vmap wrapper (e.g. batching.py's batchfun). Then it hits the C++ jit dispatch path, but that has to bail out (by calling api.py's cache_miss) because the C++ path can't handle vmap tracers as arguments. That leads to a bunch more Python work (call_bind, process_call, call_bind, ...) until we finally execute the XLA computation from Python (in _execute_compiled) and get into PjRt. Then there's a similar amount of Python work on the way back from the computation.

We plan to improve the C++ jit dispatch path to handle tracer inputs, which would cut out some of that overhead, but there'd still be a decent amount of the vmap overhead left. So it'll always be best to put jit on the outside if you can!

(Notice that in both these cases the GPU is only doing useful work for a fraction of the time. That's just because this is a really small computation.)
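One way to see why jit-of-vmap avoids the per-call vmap work described above is to trace it with `jax.make_jaxpr`. A minimal sketch, assuming toy all-ones inputs and the illustrative names `batched` and `fast`: the printed program contains a `dot_general` applied to the whole batch rather than a loop over rows, because vmap is applied once while tracing, and later calls to `fast` reuse the compiled computation.

```python
import jax
import jax.numpy as jnp

a = jnp.ones((512, 64))
b = jnp.ones((64, 64))

# Map jnp.matmul over the rows of `a`, sharing `b` across the batch.
batched = jax.vmap(jnp.matmul, in_axes=(0, None))
fast = jax.jit(batched)  # jit on the outside

# make_jaxpr traces the function without executing it; the batched dot_general
# in the output shows that vmap was resolved at trace time.
jaxpr = jax.make_jaxpr(fast)(a, b)
print(jaxpr)

out = fast(a, b)  # compiles on the first call; later calls reuse the executable
print(out.shape)  # (512, 64)
```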

Hey folks, tacking on some interesting vmap performance issues I'm seeing here as well, since the issue title fits nicely.
I was trying out automatic vectorization to replace some of our hand-batched code. The code below takes subsets of k (here 3) of the n columns of a binary dataset, multiplies those columns together, and sums the result to find what proportion of rows are all ones in those k columns.
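To make the statistic concrete, here is a worked example on a hypothetical 4x4 dataset of 0/1 values (the data, the query list and the name `single_query` are illustrative; the per-query logic mirrors `_single_query` in the code below):

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical toy dataset: 4 rows x 4 columns of 0/1 values.
D = jnp.array([
    [1, 1, 1, 0],
    [1, 0, 1, 1],
    [1, 1, 1, 1],
    [0, 1, 1, 0],
], dtype=jnp.int32)

def single_query(D, query):
    # The product across the selected columns is 1 exactly when every selected
    # entry in that row is 1; dividing the sum by the row count gives the proportion.
    return jnp.sum(jnp.prod(D[:, query], axis=1)) / D.shape[0]

queries = jnp.array([[0, 1, 2], [1, 2, 3]])
stats = vmap(single_query, in_axes=(None, 0))(D, queries)
print(stats)  # expected values: 0.5 for [0, 1, 2], 0.25 for [1, 2, 3]
```

Rows 0 and 2 are all ones in columns 0, 1 and 2 (2 of 4 rows), while only row 2 is all ones in columns 1, 2 and 3 (1 of 4 rows).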

from jax import numpy as np, vmap, jit, random
import itertools

key = random.PRNGKey(0)

# dataset
D = random.bernoulli(key, 0.5, shape=[60000, 100])
# queries are the column indices we want to multiply
queries = random.permutation(key, np.array(list(itertools.combinations(range(100), 3))))[:1000]

# a single query: multiply the columns we care about and see what proportion of them are ones
def _single_query(D, query):
    return np.sum(np.prod(D[:, query], axis=1))/D.shape[0]

# generate a function that can compute result of some pre-determined subset of queries on the dataset, vmap over queries
def auto_batched_preserve_subset_statistic(queries):
    @jit
    def compute_statistic(D):
        return jit(vmap(_single_query, (None, 0)))(D, queries)
    return compute_statistic

# generate a function that can compute result of some pre-determined subset of queries on the dataset, hand vectorize
# over the queries
def hand_batched_preserve_subset_statistic(queries):
    @jit
    def compute_statistic(D):
        temp = np.array_split(queries, 10)
        return np.concatenate([
            np.prod(D[:, q], 2).sum(0) for q in temp
        ]) / D.shape[0]
    return compute_statistic

# hand batched statistic function
hb_compute_statistic = hand_batched_preserve_subset_statistic(queries)
# vmap/auto batched statistic function
ab_compute_statistic = auto_batched_preserve_subset_statistic(queries)

And to find how long it took:

>>> %timeit ab_compute_statistic(D).block_until_ready()
2.39 ms ± 5.25 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Versus...

>>> %timeit hb_compute_statistic(D).block_until_ready()
660 µs ± 317 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

When I increased the number of queries by a factor of 10, the auto-batched version didn't close the gap; both slowed down by roughly the same factor (hb_compute_statistic -> 6 ms and ab_compute_statistic -> 23 ms).

I'm not entirely sure what's going wrong here (if anything). I made sure to put jit outside vmap, not the other way around.
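Before comparing the two versions' speed, it's worth checking that they return identical statistics. A scaled-down sketch of that check: the smaller dataset (1000 rows, 20 columns), the 100 queries, the split PRNG keys and the names `auto_batched`/`hand_batched` are my assumptions, chosen so it runs quickly, while the two bodies follow the code above.

```python
import itertools

import jax
import jax.numpy as jnp
from jax import random

key_d, key_q = random.split(random.PRNGKey(0))
D = random.bernoulli(key_d, 0.5, shape=(1000, 20))
all_combos = jnp.array(list(itertools.combinations(range(20), 3)))  # (1140, 3)
queries = random.permutation(key_q, all_combos)[:100]

def single_query(D, query):
    return jnp.sum(jnp.prod(D[:, query], axis=1)) / D.shape[0]

@jax.jit
def auto_batched(D):
    # vmap over the query axis; D is shared across all queries.
    return jax.vmap(single_query, in_axes=(None, 0))(D, queries)

@jax.jit
def hand_batched(D):
    # Process the queries in 10 explicit chunks, as in the hand-batched version.
    chunks = jnp.array_split(queries, 10)
    return jnp.concatenate(
        [jnp.prod(D[:, q], axis=2).sum(axis=0) for q in chunks]
    ) / D.shape[0]

max_diff = jnp.max(jnp.abs(auto_batched(D) - hand_batched(D)))
print(f"max |auto - hand| = {float(max_diff):.2e}")
```

One visible difference between the two formulations is that the hand-batched version gathers `D[:, q]` for 10 chunks of queries, so each intermediate covers only a tenth of the queries; this check does not establish whether XLA materializes the full gather in the vmapped version.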