google / trajax



Use of deprecated JAX APIs/behavior

pfrommerd opened this issue

Trajax uses deprecated JAX APIs and behavior, which cause warnings to be emitted at two locations:

  1. The first instance is at
    @partial(jit, static_argnums=(0, 1, 9))

    where argument index 9 is declared static although the function accepts only 8 positional arguments (valid 0-based indices are 0 through 7). This results in the warning:
    jax/_src/api_util.py:165: SyntaxWarning: Jitted function has static_argnums=(0, 1, 9), but only accepts 8 positional arguments. This warning will be replaced by an error after 2022-08-20 at the earliest.

  2. The second instance is at
    K = -sp.linalg.solve(G_, H, sym_pos=True)

    which raises the following warning:

    trajax/tvlqr.py:98: FutureWarning: The sym_pos argument to solve() is deprecated and will be removed in a future JAX release. Use assume_a='pos' instead.
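To see why index 9 is flagged in the first instance, the check can be sketched in plain Python. `out_of_range_static_argnums` is a hypothetical helper, not a JAX API, and only approximates the condition stated in the warning: `static_argnums` entries are 0-based positions, so a function with n positional parameters can mark only 0 through n-1 (or -n through -1) as static. The 8-parameter function `f` is a stand-in, not the actual trajax function.

```python
import inspect


def out_of_range_static_argnums(fun, static_argnums):
    """Return the entries of static_argnums that name no positional parameter of fun.

    Hypothetical helper approximating the condition in the JAX warning:
    indices are 0-based, so n positional parameters allow 0..n-1 or -n..-1.
    """
    kinds = [p.kind for p in inspect.signature(fun).parameters.values()]
    if inspect.Parameter.VAR_POSITIONAL in kinds:
        # A *args parameter accepts any number of positional arguments,
        # so this sketch treats every index as valid.
        return []
    positional = (inspect.Parameter.POSITIONAL_ONLY,
                  inspect.Parameter.POSITIONAL_OR_KEYWORD)
    n_pos = sum(k in positional for k in kinds)
    return [i for i in static_argnums if not -n_pos <= i < n_pos]


# Hypothetical 8-argument function standing in for the decorated trajax function.
def f(a0, a1, a2, a3, a4, a5, a6, a7):
    pass


print(out_of_range_static_argnums(f, (0, 1, 9)))  # [9]
print(out_of_range_static_argnums(f, (0, 1, 7)))  # []
```

Under this reading, the fix in trajax is either to drop the stray 9 or to replace it with the index that was actually meant to be static; which of the two applies depends on the decorated function and is not stated in this issue.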

The second warning should be fixed in #10.
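The warning text itself gives the migration for the second instance: the call becomes `K = -sp.linalg.solve(G_, H, assume_a='pos')`. Passing `assume_a='pos'` declares that `G_` is symmetric positive definite, which lets the solver use a Cholesky factorization rather than a general solve. The NumPy-only sketch below (hypothetical helper `gain_spd`, written without JAX so it runs standalone) illustrates the equivalent computation K = -inv(G_) @ H via Cholesky; it is an illustration, not trajax's code.

```python
import numpy as np


def gain_spd(G_, H):
    """Compute K = -inv(G_) @ H, assuming G_ is symmetric positive definite."""
    # Cholesky factor G_ = L @ L.T; raises LinAlgError if G_ is not positive definite.
    L = np.linalg.cholesky(G_)
    # Forward solve L y = H, then back solve L.T x = y.
    # np.linalg.solve does not exploit the triangular structure; it is used
    # here only to keep the sketch dependent on NumPy alone.
    y = np.linalg.solve(L, H)
    x = np.linalg.solve(L.T, y)
    return -x


# Usage: a small SPD G_ (built as A @ A.T plus a diagonal shift) and an arbitrary H.
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
G_ = A @ A.T + 3.0 * np.eye(3)
H = rng.standard_normal((3, 4))
K = gain_spd(G_, H)
```

The result matches a general solve, `-np.linalg.solve(G_, H)`, up to floating-point error; the positive-definiteness assumption is what makes the Cholesky route valid.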