Control the parallelism on CPU
YouJiacheng opened this issue · comments
JAX abstracts all CPU cores into a single device and doesn't provide any API to control (in a coarse-grained manner) the parallelism across multiple CPU cores.
Many problems cannot be parallelized efficiently within a single problem instance, so users may want to explore SPMD parallelism instead, i.e. use only one core per problem instance while parallelizing along the batch axis.
For instance, see #10180 (comment)
import torch
torch.set_num_threads(1)
torch.set_num_interop_threads(24)
@torch.jit.script
def batched_eigh(x: torch.Tensor):
    futs = [torch.jit._fork(torch.linalg.eigh, x[i]) for i in range(24)]
    return [torch.jit._wait(fut) for fut in futs]
https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
Since JAX has vmap, it could potentially offer a more elegant API, i.e. one that does not resort to inter-op/MPMD parallelism and does not produce a list.
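For comparison, here is a minimal sketch of how the same batched computation is written with vmap today. It only illustrates the shape of the API (one stacked result instead of a list of per-instance results); it does not control how many cores each instance uses, which is what this issue asks for. The function name eigh_single, the batch size of 24, and the 8x8 matrix size are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def eigh_single(a):
    # Eigendecomposition of one symmetric matrix (a single problem instance).
    w, v = jnp.linalg.eigh(a)
    return w, v


# vmap maps over the leading batch axis; jit compiles the batched function.
batched_eigh = jax.jit(jax.vmap(eigh_single))

# Illustrative input: 24 random symmetric 8x8 matrices.
x = jax.random.normal(jax.random.PRNGKey(0), (24, 8, 8))
x = x + jnp.swapaxes(x, -1, -2)

# Results are stacked arrays, not a Python list of per-instance results.
w, v = batched_eigh(x)
```

Here w has one row of eigenvalues per instance and v holds the matching eigenvectors, so downstream code can keep working on whole batches.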
Alternatively, JAX could make the --xla_force_host_platform_device_count XLA flag more performant, so that users can use pmap.
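A rough sketch of that alternative as it can be written today, assuming the flag is set through the XLA_FLAGS environment variable before JAX is imported. The device count of 4, the matrix size, and the name sym_eigh are illustrative assumptions, and the sketch makes no claim about performance.

```python
import os

# The flag must be set before JAX is imported; 4 is an arbitrary example value.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
# Restrict JAX to the CPU backend unless the user already chose a platform.
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp

# Query the device count instead of assuming the flag took effect.
n_dev = jax.local_device_count()


def sym_eigh(a):
    # One problem instance, executed on one (virtual) CPU device.
    w, v = jnp.linalg.eigh(a)
    return w, v


# One symmetric 8x8 matrix per device; pmap maps over the leading axis.
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 8, 8))
x = x + jnp.swapaxes(x, -1, -2)

w, v = jax.pmap(sym_eigh)(x)
```

With pmap the leading axis of the input must not exceed the number of available devices, so larger batches would need to be reshaped to (n_dev, batch_per_device, ...) and combined with vmap inside the mapped function.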