Composing `functools.partial` and `jax.vmap`
SamDuffield opened this issue
Description
I'm having issues with composing `partial` and `vmap` with specified `in_axes`:
```python
from functools import partial

import jax

def f(a, x):
    return a + x.sum()

x_all = jax.numpy.ones((3, 10))
fvmap = jax.vmap(f, in_axes=(None, 0))
pfvmap = partial(fvmap, x=x_all)
pfvmap(3.)
# ValueError: vmap in_axes must be an int, None, or a tuple of entries corresponding to the positional arguments passed to the function, but got len(in_axes)=2, len(args)=1
```
This works with a `lambda`, or of course by calling `jax.vmap(f, in_axes=(None, 0))(3., x_all)` directly.
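For reference, here is a minimal sketch of the `lambda` workaround. The issue doesn't show the exact lambda, so this form is an assumption. Inside the lambda, `x_all` is passed positionally, so it lines up with the second entry of `in_axes`:

```python
import jax
import jax.numpy as jnp

def f(a, x):
    return a + x.sum()

x_all = jnp.ones((3, 10))
fvmap = jax.vmap(f, in_axes=(None, 0))

# Assumed workaround: bind x_all positionally instead of by keyword,
# so it matches in_axes[1] = 0 and len(args) == len(in_axes).
pfvmap = lambda a: fvmap(a, x_all)

out = pfvmap(3.)  # shape (3,); each entry is 3 + (sum of ten ones) = 13
```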
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.24.3
python: 3.11.4 (main, Jul 5 2023, 08:54:11) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Sams-MacBook-Pro.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')
```
Thanks for the question! This has nothing to do with `partial`; you can get the same error by passing the keyword argument directly:

```python
fvmap(3.0, x=x_all)
```
The issue is that `in_axes` applies only to positional arguments, not to keyword arguments. Keyword arguments passed to a vmapped function are always mapped along axis 0, so you could use that fact to solve your problem this way:
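To illustrate the rule, here is a small sketch using a hypothetical three-argument function `g` (not from the original report). The `in_axes` tuple has one entry per positional argument only; the keyword argument `x` gets no entry and is mapped along axis 0:

```python
import jax
import jax.numpy as jnp

# Hypothetical function for illustration only.
def g(a, b, x):
    return a + b + x.sum()

b_all = jnp.arange(3.0)    # positional: mapped along axis 0 via in_axes[1]
x_all = jnp.ones((3, 10))  # keyword: always mapped along axis 0

# Two entries, one for each positional argument (a, b); x is not counted.
gvmap = jax.vmap(g, in_axes=(None, 0))
out = gvmap(3.0, b_all, x=x_all)  # out[i] = 3 + b_all[i] + 10
```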
```python
fvmap = jax.vmap(f, in_axes=None)
pfvmap = partial(fvmap, x=x_all)  # Note: keyword argument implicitly has in_axes=0
pfvmap(3.)
```
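As a quick sanity check, assuming the same `f` and `x_all` as in the original report, the `partial` version with `in_axes=None` should match the explicit positional call:

```python
from functools import partial

import jax
import jax.numpy as jnp

def f(a, x):
    return a + x.sum()

x_all = jnp.ones((3, 10))

# in_axes=None leaves every positional argument (here just `a`) unmapped;
# the keyword argument `x` is still mapped along axis 0.
fvmap = jax.vmap(f, in_axes=None)
pfvmap = partial(fvmap, x=x_all)

result = pfvmap(3.)
reference = jax.vmap(f, in_axes=(None, 0))(3., x_all)  # original positional call
```

One consequence of this design is that every keyword argument bound this way is mapped along axis 0. An argument that should stay unmapped needs to be passed positionally (with a `None` entry in `in_axes`) or closed over, for example with a `lambda`.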
Thank you for the quick reply!! That makes sense and indeed indicates a solution to my problem 😆