google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

shard_map implementation of pmap only works for the first input signature

jheek opened this issue

Description

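from jax.experimental.shard_map import pmap
from jax import numpy as jnp

def f(x):
  return x * x

# pmap implemented in terms of shard_map
f = pmap(f, axis_name='batch')
f(jnp.ones((8,)))     # first input signature, shape (8,): works
f(jnp.ones((8, 2)))   # second input signature, shape (8, 2): raises StoreException: Store occupied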

Calling the pmapped function with more than one input signature (here, a second input shape) fails in the shard_map implementation of pmap.
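For comparison, the expected behavior is that a pmapped function simply retraces and recompiles for each new input signature. The sketch below assumes the standard `jax.pmap` (not the shard_map-based one from the report) and sizes the mapped axis with `jax.local_device_count()` so that it also runs on a single-device CPU machine; `f_pmap` and `n` are illustrative names, not part of the original report.

```python
import jax
from jax import numpy as jnp

def f(x):
  return x * x

# Standard pmap (assumed here instead of the shard_map-based pmap):
# it traces and compiles once per distinct input signature.
f_pmap = jax.pmap(f, axis_name='batch')

# pmap maps over the leading axis, which must not exceed the number of local devices.
n = jax.local_device_count()

y1 = f_pmap(jnp.ones((n,)))      # first signature: shape (n,)
y2 = f_pmap(jnp.ones((n, 2)))    # second signature: shape (n, 2), triggers a retrace
print(y1.shape, y2.shape)
```

With the standard implementation, both calls return arrays of ones with shapes `(n,)` and `(n, 2)`; the report says the shard_map-based pmap instead fails on the second call.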

System info (python version, jaxlib version, accelerator, etc.)

Latest version of JAX.

This is not polished or fully tested yet. It's purely experimental right now.