shard_map implementation of pmap only works for the first input signature
jheek opened this issue
jheek commented
Description
```python
from jax.experimental.shard_map import pmap
from jax import numpy as jnp

def f(x):
    return x * x

f = pmap(f, axis_name='batch')
f(jnp.ones((8,)))     # first call, input shape (8,): works
f(jnp.ones((8, 2)))   # raises StoreException: Store occupied
```
Calling the function with a second input signature (here, a different input shape) fails with the shard_map implementation of pmap. Only the first signature works.
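If the immediate goal is just to map `f` over a device axis for several input shapes, one possible workaround is to skip the experimental `pmap` wrapper and call `shard_map` directly under `jax.jit`, which retraces for each new input shape. This is a sketch under assumptions, not a fix for the wrapper: the names `mesh` and `g` are illustrative, the mesh simply spans all available devices, and the import fallback is there because newer JAX releases expose `jax.shard_map` while older ones only have `jax.experimental.shard_map`. Note that, unlike `pmap`, `shard_map` hands each device a block that keeps the leading axis rather than a squeezed slice. For an elementwise `f` like `x * x`, the results are the same.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases only have the experimental module
    from jax.experimental.shard_map import shard_map


def f(x):
    return x * x


# Illustrative 1-D mesh over all available devices; the axis name mirrors axis_name='batch'.
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

# Split the leading axis across 'batch'; each device gets a block of shape
# (len(x) // n_devices, ...) that keeps the leading axis, unlike pmap's squeezed slices.
g = jax.jit(shard_map(f, mesh=mesh, in_specs=P('batch'), out_specs=P('batch')))

n = len(jax.devices())
y1 = g(jnp.ones((n,)))     # first signature: traced and compiled
y2 = g(jnp.ones((n, 2)))   # new shape: jit traces again instead of reusing the first trace
print(y1.shape, y2.shape)
```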
System info (python version, jaxlib version, accelerator, etc.)
Latest version of JAX.
Yash Katariya commented
This is not polished or fully tested yet. It's purely experimental right now.