google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/


Control the parallelism on CPU

YouJiacheng opened this issue

JAX abstracts all CPU cores into a single device and doesn't provide any API to control, even in a coarse-grained manner, the parallelism across multiple CPU cores.
Since many problems are inefficient to parallelize within a single problem instance, users may want to exploit SPMD parallelism, i.e. use only one core per problem instance while parallelizing along the batch axis.
For instance, see #10180 (comment)

```python
import torch

# Limit each problem instance to a single intra-op thread, and allow up to
# 24 inter-op threads so the forked calls below can run in parallel.
torch.set_num_threads(1)
torch.set_num_interop_threads(24)

@torch.jit.script
def batched_eigh(x: torch.Tensor):
    # Fork one eigh per batch element and wait on the futures,
    # which yields a list of results rather than a batched array.
    futs = [torch.jit._fork(torch.linalg.eigh, x[i]) for i in range(24)]
    return [torch.jit._wait(fut) for fut in futs]
```

https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
Since JAX has vmap, it could potentially offer a more elegant API, i.e. one that does not resort to inter-op/MPMD parallelism and does not produce a list.
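For comparison, here is a minimal sketch of what the batched computation looks like with today's vmap on CPU (the 24-element batch and the 128×128 matrix size are placeholders mirroring the PyTorch example, not from the original report); note that how the work is distributed across cores is left entirely to XLA, which is the gap this issue is about:

```python
import jax
import jax.numpy as jnp

# A batch of 24 symmetric (PSD) matrices, analogous to the PyTorch example above.
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (24, 128, 128))
x = a @ jnp.swapaxes(a, -1, -2)

# vmap gives a batched eigh directly, without forking futures or returning a
# list, but everything runs on the single CPU device and the user has no way
# to request "one core per batch element".
batched_eigh = jax.jit(jax.vmap(jnp.linalg.eigh))
eigvals, eigvecs = batched_eigh(x)
```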
Alternatively, JAX could make the --xla_force_host_platform_device_count XLA flag more performant, so that users can simply use pmap.
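A rough sketch of that alternative, assuming the flag is passed via the XLA_FLAGS environment variable before JAX initializes its backends, and reusing the same placeholder shapes:

```python
import os
# Must be set before JAX initializes; splits the host CPU into 24 "devices"
# so that pmap can place one batch element on each of them.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=24"

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (len(jax.devices()), 128, 128))
x = a @ jnp.swapaxes(a, -1, -2)

# One eigh per CPU "device"; whether this actually maps to one core per
# problem instance depends on how XLA schedules the host devices, which is
# the performance concern raised above.
eigvals, eigvecs = jax.pmap(jnp.linalg.eigh)(x)
```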