`LevenbergMarquardt` implementation does not accept PyTree parameters
Joshuaalbert opened this issue · comments
Joshua G Albert commented
Description
The LevenbergMarquardt
implementation does not accept PyTree parameters, giving TypeError: primal and tangent arguments to jax.jvp must have the same tree structure
at levenberg_marquardt.py, line 534.
MVCE
from dataclasses import dataclass
from typing import Literal, NamedTuple, Tuple
import jaxopt
from jax import numpy as jnp
class CalibrationParams(NamedTuple):
gains_real: jnp.ndarray # [source, time, ant, chan, 2, 2]
gains_imag: jnp.ndarray # [source, time, ant, chan, 2, 2]
class CalibrationData(NamedTuple):
gains_real: jnp.ndarray # [source, time, ant, chan, 2, 2]
gains_imag: jnp.ndarray # [source, time, ant, chan, 2, 2]
@dataclass(eq=False)
class Calibration:
convention: Literal['fourier', 'casa'] = 'casa'
dtype: jnp.dtype = jnp.complex64
chunksize: int = 1
unroll: int = 1
def _residual_fun(self, params: CalibrationParams, data: CalibrationData) -> jnp.ndarray:
residuals = jnp.concatenate([
(params.gains_real - data.gains_real).ravel(),
(params.gains_imag - data.gains_imag).ravel()
])
return residuals
@property
def float_dtype(self):
# Given self.dtype is complex, find float dtype
return jnp.real(jnp.zeros((), dtype=self.dtype)).dtype
def get_init_params(self, shape) -> CalibrationParams:
"""
Get initial parameters.
Args:
shape: shape of gains_real and gains_imag
Returns:
initial parameters
"""
return CalibrationParams(
gains_real=jnp.ones(shape, self.float_dtype),
gains_imag=jnp.zeros(shape, self.float_dtype)
)
def solve(self, init_params: CalibrationParams, data: CalibrationData) -> Tuple[CalibrationParams, jaxopt.OptStep]:
solver = jaxopt.LevenbergMarquardt(
residual_fun=self._residual_fun,
maxiter=100,
jit=True,
unroll=False,
materialize_jac=False,
geodesic=False,
implicit_diff=False
)
opt_result = solver.run(init_params=init_params, data=data)
params = opt_result.params
return params, opt_result
if __name__ == '__main__':
calibration = Calibration()
shape = (10, 100, 100, 100, 2, 2)
init_params = calibration.get_init_params(shape)
data = CalibrationData(
gains_real=jnp.ones(shape, calibration.float_dtype),
gains_imag=jnp.zeros(shape, calibration.float_dtype)
)
params, opt_results = calibration.solve(init_params=init_params, data=data)
print(params)
print(opt_results)
File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/base.py", line 359, in run
return run(init_params, *args, **kwargs)
File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/base.py", line 301, in _run
state = self.init_state(init_params, *args, **kwargs)
File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 216, in init_state
jtj_diag = self._jtj_diag_op(init_params, *args, **kwargs)
File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 535, in _jtj_diag_op
return jax.vmap(diag_op)(jnp.eye(len(params))).T
File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 534, in <lambda>
diag_op = lambda v: v.T @ self._jtj_op(params, v, *args, **kwargs)
File "/home/albert/miniconda3/envs/dsa_py/lib/python3.10/site-packages/jaxopt/_src/levenberg_marquardt.py", line 528, in _jtj_op
_, jvp_val = jax.jvp(fun_with_args, (params,), (vec,))
TypeError: primal and tangent arguments to jax.jvp must have the same tree structure; primals have tree structure PyTreeDef((CustomNode(namedtuple[CalibrationParams], [*, *]),)) whereas tangents have tree structure PyTreeDef((*,)).