google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.

Home Page: https://jaxopt.github.io

LevenbergMarquardt does not seem to work with non-flat input.

bolducke opened this issue

GaussNewton works as intended with pytrees. I would expect the same for LM. Instead, I had to flatten the array to make it work properly.

The error appears at line 445. It comes from the fact that the pytree structures of params and vec do not match.
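To illustrate what a pytree mismatch of this kind looks like, here is a minimal sketch using only `jax.tree_util`. The names `params` and `vec` below are hypothetical stand-ins, not the actual variables inside jaxopt, and the sketch does not claim to reproduce the library's internal code path.

```python
import jax.numpy as jnp
from jax import tree_util

# Hypothetical stand-ins: `params` is a dict pytree, `vec` is one flat array.
params = {"A": jnp.zeros(3), "B": jnp.zeros(3)}
vec = jnp.zeros(6)

params_def = tree_util.tree_structure(params)
vec_def = tree_util.tree_structure(vec)

# Both hold the same number of scalars, yet their tree structures differ,
# so leaf-wise operations pairing them up cannot line the leaves up.
same_size = sum(leaf.size for leaf in tree_util.tree_leaves(params)) == vec.size
same_structure = params_def == vec_def

print(params_def)
print(vec_def)
print(same_size, same_structure)
```

Comparing `tree_structure` results like this is a quick way to check whether two objects can be combined leaf by leaf.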

Hi @bolducke thanks for reporting this. Do you have an example that I can use for the repro? I plan to update the unit tests to cover that use case for both GN and LM.

@amir-saadat I was going to report this as well. I wrote a super simple test case to demonstrate, though I'm not sure it's what you're looking for.

import jax
jax.config.update("jax_enable_x64", True)  # enable 64-bit floats before creating arrays
import jax.numpy as jnp
import jaxopt

M = 5
params = jnp.zeros((2,M))
params = params.at[0].set(jnp.arange(M) * 1.0)
params = params.at[1].set(jnp.arange(M)**2 * 1.0)
params_dict = {'A': jnp.arange(M) * 1.0, 'B': jnp.arange(M)**2 * 1.0}

def F(params):
	return jnp.asarray([jnp.sum(params[0]), jnp.sum(params[0] * params[1]**2)])

def F_dict(params):
	return jnp.asarray([jnp.sum(params['A']), jnp.sum(params['A'] * params['B']**2)])

def optimize_F_gn(params, F):
	gn = jaxopt.GaussNewton(residual_fun=F)
	return gn.run(params).params

def optimize_F_lm(params, F):
	lm = jaxopt.LevenbergMarquardt(residual_fun=F)
	return lm.run(params).params

print(params)
print(optimize_F_gn(params, F)) 
print(optimize_F_gn(params_dict, F_dict))
print(optimize_F_lm(params, F)) # fails
print(optimize_F_lm(params_dict, F_dict)) # fails
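The flattening workaround mentioned at the top of the thread could be sketched roughly as follows, using `jax.flatten_util.ravel_pytree`. The helper `flatten_problem` is a hypothetical name introduced here for illustration and is not part of jaxopt. The solver call is left as a comment because it is assumed usage rather than something verified in this thread.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def flatten_problem(residual_fun, params):
    """Hypothetical helper: wrap a pytree residual so it takes a flat vector."""
    flat_params, unravel = ravel_pytree(params)

    def flat_residual_fun(flat):
        # Rebuild the original pytree before calling the user's residual.
        return residual_fun(unravel(flat))

    return flat_residual_fun, flat_params, unravel

M = 5
params_dict = {'A': jnp.arange(M) * 1.0, 'B': jnp.arange(M)**2 * 1.0}

def F_dict(params):
    return jnp.asarray([jnp.sum(params['A']), jnp.sum(params['A'] * params['B']**2)])

F_flat, flat_params, unravel = flatten_problem(F_dict, params_dict)

# The flat residual agrees with the pytree residual at the starting point.
print(F_flat(flat_params), F_dict(params_dict))

# Assumed usage with the solver (requires jaxopt; not run here):
#   sol = jaxopt.LevenbergMarquardt(residual_fun=F_flat).run(flat_params)
#   params_opt = unravel(sol.params)  # back to the original dict structure
```

With this wrapper the solver only ever sees a one-dimensional array, and `unravel` restores the dict layout afterwards.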