google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Composing `functools.partial` and `jax.vmap`

SamDuffield opened this issue

Description

I'm having trouble composing `partial` and `vmap` when `in_axes` is specified:

from functools import partial

import jax

def f(a, x):
    return a + x.sum()

x_all = jax.numpy.ones((3, 10))

fvmap = jax.vmap(f, in_axes=(None, 0))
pfvmap = partial(fvmap, x=x_all)
pfvmap(3.)
# ValueError: vmap in_axes must be an int, None, or a tuple of entries corresponding to the positional arguments passed to the function, but got len(in_axes)=2, len(args)=1

This works with a lambda, or, of course, when calling `jax.vmap(f, in_axes=(None, 0))(3., x_all)` directly.
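A minimal sketch of the working alternatives, assuming the same `f` and `x_all` as above. The exact lambda isn't shown in the report, so its form here is an assumption, and option 3 (binding `a` with `partial` *before* applying `vmap`) is an extra suggestion rather than something from the thread:

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(a, x):
    return a + x.sum()


x_all = jnp.ones((3, 10))
fvmap = jax.vmap(f, in_axes=(None, 0))

# 1. A lambda that forwards both arguments positionally, so that
#    in_axes=(None, 0) lines up with (a, x).
pfvmap_lambda = lambda a: fvmap(a, x_all)

# 2. Calling the vmapped function directly with positional arguments.
direct = fvmap(3.0, x_all)

# 3. Binding `a` with partial *before* vmap: vmap then only sees `x`,
#    which is mapped along axis 0 by the default in_axes=0.
bound_first = jax.vmap(partial(f, 3.0))(x_all)

print(pfvmap_lambda(3.0))  # [13. 13. 13.]
print(direct)              # [13. 13. 13.]
print(bound_first)         # [13. 13. 13.]
```

Option 3 reverses the order of composition: because `partial` is applied first, the function that `vmap` wraps has a single positional parameter, so no `in_axes` bookkeeping is needed.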

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.24.3
python: 3.11.4 (main, Jul  5 2023, 08:54:11) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Sams-MacBook-Pro.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')

Thanks for the question! This has nothing to do with `partial`; you can get the same error by passing the keyword argument directly:

fvmap(3.0, x=x_all)

The issue is that `in_axes` applies only to positional arguments, not to keyword arguments. By convention, all keyword arguments to a vmapped function are mapped along axis 0, so you can use that fact to solve your problem this way:

fvmap = jax.vmap(f, in_axes=None)
pfvmap = partial(fvmap, x=x_all)  # Note: keyword argument implicitly has in_axes=0 
pfvmap(3.)
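A self-contained version of the suggested fix, assuming the `f` and `x_all` from the original report. The second half is an added illustration (not from the thread) of the flip side of the same rule: since keyword arguments are always mapped along axis 0, passing the scalar `a` by keyword fails, because a rank-0 value has no axis 0 to map over. The exact error text may differ between JAX versions, so the sketch only relies on the exception type being `ValueError`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(a, x):
    return a + x.sum()


x_all = jnp.ones((3, 10))

# in_axes=None applies only to the positional argument `a` (left unmapped);
# the keyword argument `x` is mapped along axis 0 regardless.
fvmap = jax.vmap(f, in_axes=None)
pfvmap = partial(fvmap, x=x_all)
result = pfvmap(3.0)
print(result)  # [13. 13. 13.]

# Flip side: a keyword argument cannot opt out of mapping. Passing the
# scalar `a` by keyword asks vmap to map a rank-0 value along axis 0.
raised = False
try:
    fvmap(a=3.0, x=x_all)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```

In other words, any argument that should stay unmapped has to be passed positionally, where `in_axes` can mark it with `None`.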

Thank you for the quick reply! That makes sense and indeed points to a solution to my problem 😆