LeCAR-Lab / CoVO-MPC

Official implementation for the paper "CoVO-MPC: Theoretical Analysis of Sampling-based MPC and Optimal Covariance Design" accepted by L4DC 2024. CoVO-MPC is an optimal sampling-based MPC algorithm.

Home Page: https://lecar-lab.github.io/CoVO-MPC/

🐛TypeError: unsupported operand type(s) for *: 'Symbol' and 'DynamicJaxprTracer'

bzx20 opened this issue

❓ Issue

To check the taut dynamics of the dual-drone environment, I updated the new state using only the taut dynamics, like this:

new_state = self.taut_dynamics(params, state, env_action)

Then I ran quad_dual2d.py as a test and got the following error:

Traceback (most recent call last):
  File "/home/bianzx/jax/quadjax/quadjax/quad_dual2d.py", line 641, in <module>
    main(tyro.cli(Args))
  File "/home/bianzx/jax/quadjax/quadjax/quad_dual2d.py", line 638, in main
    test_env(env, policy=random_policy, render_video=args.render)
  File "/home/bianzx/jax/quadjax/quadjax/quad_dual2d.py", line 424, in test_env
    next_obs, next_env_state, reward, done, info = env.step(
  File "/home/bianzx/miniconda3/envs/pytorch3d/lib/python3.9/site-packages/gymnax-0.0.6-py3.9.egg/gymnax/environments/environment.py", line 38, in step
    obs_st, state_st, reward, done, info = self.step_env(
  File "/home/bianzx/jax/quadjax/quadjax/quad_dual2d.py", line 87, in step_env
    new_state = self.taut_dynamics(params, state, env_action)
  File "/home/bianzx/jax/quadjax/quadjax/dynamics/taut_dual.py", line 212, in taut_dynamics
    b = b_taut_dyn_func(*params, *states, *action)
  File "<lambdifygenerated-2>", line 2, in _lambdifygenerated
TypeError: unsupported operand type(s) for *: 'Symbol' and 'DynamicJaxprTracer'

In taut_dual.py:

Define:

# lambdify A_taut_dyn and b_taut_dyn with the JAX backend
A_taut_dyn_func = sp.lambdify(
    params + states_val + action, A_taut_dyn, "jax")
b_taut_dyn_func = sp.lambdify(
    params + states_val + action, b_taut_dyn, "jax")

Use:

A = A_taut_dyn_func(*params, *states, *action)
b = b_taut_dyn_func(*params, *states, *action)

It seems that the lambdified expression has trouble with JAX tracers. However, I did not encounter this problem before in the single-drone environment.

🤔 Possible solutions

ChatGPT suggests:

A_taut_dyn_func = sp.lambdify(
    params + states_val + action, A_taut_dyn, "numpy")

b_taut_dyn_func = sp.lambdify(
    params + states_val + action, b_taut_dyn, "numpy")

def A_taut_dyn_jax(*args):
    return jnp.array(A_taut_dyn_func(*args)).astype(jnp.float32)

def b_taut_dyn_jax(*args):
    return jnp.array(b_taut_dyn_func(*args)).astype(jnp.float32)

But it did not work.
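Assuming the root cause is a symbol missing from the argument list, a toy version suggests why switching the backend to "numpy" does not help: the multiplication by the leftover Symbol happens inside the generated function, before the jnp.array conversion is ever reached. The wrapper below imitates the suggested workaround; x, k, and wrap are hypothetical names.

```python
import jax
import jax.numpy as jnp
import sympy as sp

# Hypothetical toy expression; k is deliberately left out of the first argument list.
x, k = sp.symbols("x k")
expr = k * x

def wrap(f):
    # Same pattern as the suggested workaround: NumPy lambdify, then convert to jnp.
    return lambda *args: jnp.array(f(*args)).astype(jnp.float32)

missing_np = wrap(sp.lambdify([x], expr, "numpy"))
complete_np = wrap(sp.lambdify([k, x], expr, "numpy"))

try:
    jax.jit(missing_np)(2.0)  # still fails: the Symbol k meets a tracer first
except TypeError as err:
    print("still fails:", type(err).__name__)

print(jax.jit(complete_np)(3.0, 2.0))  # 6.0 once k is passed as an argument
```

In this toy case the backend makes no difference; only the completeness of the argument list does.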

commented

The issue you're encountering typically stems from an incomplete argument list passed to lambdify: some symbols that appear in the original symbolic expression have been omitted from it.

For example, in taut_dual.py:43:

params = [m, I, g, l, mo, delta_yh, delta_zh]

It appears that some parameters have been overlooked:

    # y displacement of the hook from the quadrotor center
    delta_yh2 = sp.Symbol("delta_yh2")
    # z displacement of the hook from the quadrotor center
    delta_zh2 = sp.Symbol("delta_zh2")

To avoid this issue, make sure all symbolic variables are included when lambdifying the function. You can verify this by checking params, states_val, and action again.
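As a minimal sketch of that check (missing_symbols and checked_lambdify are hypothetical helpers, not part of SymPy or the repository), the free symbols of the expression can be compared against the argument list before lambdifying. Arguments such as y(t) are swapped for Dummy symbols first, so their time variable t is not reported as missing.

```python
import sympy as sp

def missing_symbols(expr, args):
    """Return the free symbols of expr that no entry of args accounts for.

    Each argument (a Symbol or an applied function such as y(t)) is first
    replaced by a fresh Dummy, so y(t) counts as covered and its time
    variable t is not reported as missing.
    """
    dummies = {arg: sp.Dummy() for arg in args}
    leftover = expr.xreplace(dummies).free_symbols
    return leftover - set(dummies.values())

def checked_lambdify(args, expr, modules="jax"):
    """Hypothetical wrapper: lambdify only if every free symbol is in args."""
    missing = missing_symbols(expr, args)
    if missing:
        names = sorted(str(s) for s in missing)
        raise ValueError(f"symbols missing from lambdify args: {names}")
    return sp.lambdify(args, expr, modules)

# Hypothetical toy example: k is used in the expression but not passed as an argument.
t, k = sp.symbols("t k")
y = sp.Function("y")(t)
print(missing_symbols(k * y, [y]))     # {k}
print(missing_symbols(k * y, [k, y]))  # set()
```

SymPy matrices also provide xreplace and free_symbols, so the same helper should apply to matrix-valued expressions such as A_taut_dyn and b_taut_dyn. Raising at definition time points directly at the missing names instead of surfacing later as a tracer error inside env.step.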

commented

I completed the argument list in params by adding the missing symbolic variables, and I also checked params, states_val, and action by printing them out in test.ipynb:
params

[m,
 I,
 g,
 l,
 mo,
 delta_yh,
 delta_zh,
 delta_yh2,
 delta_zh2]

states_val

[y(t),
 z(t),
 theta(t),
 phi(t),
 y_dot_val,
 z_dot_val,
 theta_dot_val,
 phi_dot_val,
 y2(t),
 z2(t),
 theta2(t),
 phi2(t),
 y2_dot_val,
 z2_dot_val,
 theta2_dot_val,
 phi2_dot_val]

action

[thrust(t),
 tau(t),
 thrust2(t),
 tau2(t)]

It appears that I have included all the symbolic variables, but I still encounter the same error: TypeError: unsupported operand type(s) for *: 'Symbol' and 'DynamicJaxprTracer'.

PS: In debug mode with a breakpoint at b = b_taut_dyn_func(*params, *states, *action), the error is raised on that line, while the preceding call A = A_taut_dyn_func(*params, *states, *action) runs fine.
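One possible reading of this observation, assuming a leftover-symbol cause, is that the stray symbol occurs only in entries of b_taut_dyn, so A_taut_dyn_func traces cleanly while b_taut_dyn_func does not. The hypothetical helper below reports, entry by entry, which symbols a matrix uses that the argument list does not cover; the toy matrices and the symbol c are illustrative only.

```python
import sympy as sp

def leftover_report(mat, args):
    """Map each flat entry index of mat to the symbols that args do not cover."""
    # Swap each argument (Symbol or applied function like y(t)) for a Dummy.
    dummies = {arg: sp.Dummy() for arg in args}
    covered = set(dummies.values())
    report = {}
    for i, entry in enumerate(mat):  # row-major iteration over the entries
        left = entry.xreplace(dummies).free_symbols - covered
        if left:
            report[i] = sorted(str(s) for s in left)
    return report

# Hypothetical toy matrices: only the second entry of b_toy uses the stray symbol c.
x, k, c = sp.symbols("x k c")
A_toy = sp.Matrix([[k, 0], [0, k]])
b_toy = sp.Matrix([k * x, c * x])
print(leftover_report(A_toy, [x, k]))  # {}
print(leftover_report(b_toy, [x, k]))  # {1: ['c']}
```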

commented

This problem is actually caused by an incomplete substitution of the symbolic variables.
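The reply does not say which substitution was incomplete. As a minimal sketch, assuming the y_dot_val-style symbols in states_val stand in for time derivatives of the state functions (an inference from their names only), every derivative can be replaced in a single subs call and the result checked for leftover Derivative objects before lambdifying. The symbols k, y, and y_dot_val here are hypothetical toy stand-ins.

```python
import sympy as sp

# Hypothetical toy stand-ins for a parameter, a state y(t), and its velocity symbol.
t, k = sp.symbols("t k")
y = sp.Function("y")(t)
y_dot_val = sp.Symbol("y_dot_val")

expr = k * sp.diff(y, t) + y

# Substitute every time derivative with its value symbol in one subs call.
expr_sub = expr.subs({sp.diff(y, t): y_dot_val})

# Fail early if any derivative survived the substitution.
assert not expr_sub.atoms(sp.Derivative), "a derivative was left unsubstituted"

# lambdify accepts the applied function y(t) as an argument alongside plain symbols.
f = sp.lambdify([k, y, y_dot_val], expr_sub, "math")
print(f(2.0, 1.0, 3.0))  # 7.0
```

Keeping the substitution map in one dictionary makes it easy to confirm that every state of both drones has an entry.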