google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.

Home Page: https://jaxopt.github.io

BoxOSQP does not work without equality constraints

jewillco opened this issue

Trying to run:

import jax.numpy as jnp
import jaxopt

optimizer = jaxopt.BoxOSQP()
optimizer.run(
    params_obj=(
        jnp.eye(30, dtype=jnp.float32),
        jnp.ones((30,), dtype=jnp.float32),
    ),
    params_ineq=(-1, 1),  # note: no params_eq (matrix A) is passed
)

produces an error (some parts redacted):

[.../jaxopt/_src/osqp.py](...) in run(self, init_params, params_obj, params_eq, params_ineq)
    763       init_params = self.init_params(None, params_obj, params_eq, params_ineq)
    764 
--> 765     return super().run(init_params, params_obj, params_eq, params_ineq)
    766 
    767   def l2_optimality_error(

[.../jaxopt/_src/base.py](...) in run(self, init_params, *args, **kwargs)
    345       run = decorator(run)
    346 
--> 347     return run(init_params, *args, **kwargs)
    348 
    349   def __post_init__(self):

[.../jaxopt/_src/implicit_diff.py](...) in wrapped_solver_fun(*args, **kwargs)
    249     args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250     keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251     return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
    252 
    253   return wrapped_solver_fun

[.../jaxopt/_src/implicit_diff.py](...) in solver_fun_flat(*flat_args)
    205     def solver_fun_flat(*flat_args):
    206       args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207       return solver_fun(*args, **kwargs)
    208 
    209     def solver_fun_fwd(*flat_args):

[.../jaxopt/_src/base.py]() in _run(self, init_params, *args, **kwargs)
    287            *args,
    288            **kwargs) -> OptStep:
--> 289     state = self.init_state(init_params, *args, **kwargs)
    290 
    291     # We unroll the very first iteration. This allows `init_val` and `body_fun`

[.../jaxopt/_src/osqp.py](...) in init_state(self, init_params, params_obj, params_eq, params_ineq)
    456     A    = self.matvec_A(params_eq)
    457 
--> 458     primal_residuals, dual_residuals = self._compute_residuals(Q, c, A, x, z, y)
    459     solver_state = self._eq_qp_solve_impl.init_state(x, Q_params, params_eq,
    460                                                      self.sigma, self.rho_start)

[.../jaxopt/_src/osqp.py](...) in _compute_residuals(self, Q, c, A, x, z, y)
    530     Ax, ATy = A.matvec_and_rmatvec(x, y)
    531     primal_residuals = tree_sub(Ax, z)
--> 532     dual_residuals = tree_add(tree_add(Q(x), c), ATy)
    533     return primal_residuals, dual_residuals
    534 

TypeError: unsupported operand type(s) for +: 'ArrayImpl' and 'NoneType'

Hi!
Indeed, the BoxOSQP formulation does not make sense without the constraint matrix A (passed through params_eq).
Here is the formulation solved by OSQP:

$$\min_x \frac{1}{2}x^TQx+c^Tx\text{ s.t. }l\leq Ax\leq u.$$

As a reminder, here is the API:

  • params_obj contains either (Q, c), or (params_Q, c) where params_Q are the parameters of matvec_Q, or params_fun, the parameters of a function fun that is promised to be quadratic.
  • params_ineq contains (l, u). If you specify l and u, you must also specify A; otherwise the problem is not defined. Check the documentation.
  • params_eq contains either A, or params_A, the parameters of matvec_A.

I am not sure what your intent was when leaving A unspecified. Did you expect A to be the identity? Then use matvec_A=(lambda _, x: x) and pass params_eq=None.

I did get confused by the API, but I would not have expected a crash like that. Would a better error message be possible? Also, does the problem formulation with an arbitrary $A$ mean that this can handle an arbitrary set of linear constraints, not just a box?

Also, does the problem formulation with an arbitrary $A$ mean that this can handle an arbitrary set of linear constraints, not just a box?

Yes, it is the most general linearly constrained QP formulation one can think of. Maybe you are more familiar with:

$$\min_x\frac{1}{2}x^TQx+c^Tx\text{ s.t. }Ax=b,\ Gx\leq h$$

Now, observe that:

$$Ax=b\iff b\leq Ax\leq b$$

and

$$Gx\leq h\iff -\infty\leq Gx\leq h$$
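Stacking the two groups of constraints gives a single constraint of the BoxOSQP form $l\leq Ax\leq u$, where the stacked matrix plays the role of BoxOSQP's $A$:

$$\begin{pmatrix}b\\-\infty\end{pmatrix}\leq\begin{pmatrix}A\\G\end{pmatrix}x\leq\begin{pmatrix}b\\h\end{pmatrix}$$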

so you can implement it as:

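import jax.numpy as jnp
from jaxopt import BoxOSQP

# Example data: min 1/2 x^T Q x + c^T x  s.t.  x1 + x2 = 1,  x1 <= 0.2
Q, c = jnp.eye(2), jnp.zeros(2)
A, b = jnp.array([[1.0, 1.0]]), jnp.array([1.0])
G, h = jnp.array([[1.0, 0.0]]), jnp.array([0.2])
init_x = jnp.zeros(2)

def matvec_A_boxosqp(params_A, x):
    A, G = params_A
    return A @ x, G @ x  # one output block per constraint group

l = (b, jnp.full(shape=h.shape, fill_value=-jnp.inf))  # Gx has no lower bound
u = (b, h)
params_ineq = (l, u)
params_eq = (A, G)

solver = BoxOSQP(matvec_A=matvec_A_boxosqp)
init_params = solver.init_params(init_x, (Q, c), params_eq, params_ineq)
solver.run(init_params, params_obj=(Q, c), params_eq=params_eq, params_ineq=params_ineq)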

This reduction between problem formulations is performed automatically by the OSQP solver, but its API is higher level and therefore less flexible.

Would a better error message be possible?

I agree we should improve the readability of error messages.
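One possible shape for such a message is an early argument check. The sketch below uses a hypothetical helper, check_boxosqp_args, that is not part of jaxopt; it only illustrates failing fast with an explanation instead of letting None reach the residual computation.

```python
def check_boxosqp_args(params_eq, params_ineq, matvec_A=None):
    """Hypothetical early check for BoxOSQP-style arguments (not jaxopt API).

    Raises a ValueError with an explicit message instead of letting a
    None value for A propagate into the residual computation.
    """
    if params_ineq is None:
        return
    # params_ineq is expected to be the pair (l, u).
    if not (isinstance(params_ineq, (tuple, list)) and len(params_ineq) == 2):
        raise ValueError("params_ineq must be a pair (l, u).")
    # Without A (dense or via matvec_A), l <= Ax <= u is undefined.
    if params_eq is None and matvec_A is None:
        raise ValueError(
            "params_ineq=(l, u) was given but params_eq (the matrix A in "
            "l <= Ax <= u) is None. For box constraints l <= x <= u, pass "
            "matvec_A=lambda _, x: x to the constructor and params_eq=None."
        )
```

With this check, the call from the original report would stop with a message pointing at the missing A rather than with a TypeError deep inside the solver.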

What I was trying to get at is whether that means arbitrary linear constraints are supported, not just box constraints. Does that mean this class is a full QP solver, and not just for box-constrained problems?

Does that mean this class is a full QP solver, and not just for box-constrained problems?

Yes, it is! Just like the original one: https://osqp.org/