google / trajax

CEM

mkolodziejczyk-piap opened this issue

Hi, I noticed a couple of problems when running CEM.

  1. In cem, the last index in static_argnums should be 7, not 9:

    @partial(jit, static_argnums=(0, 1, 9))

    The same problem appears in random_shooting:

    @partial(jit, static_argnums=(0, 1, 9))

  2. The default hyperparameters should be a frozendict:
    https://github.com/google/trajax/blob/main/trajax/optimizers.py#L685-L692
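One way to catch an off-by-two like the one in item 1 is to map each entry of static_argnums back to a parameter name before jitting. The sketch below uses only the standard inspect module. static_arg_names and cem_like are hypothetical helpers, and cem_like assumes that cem and random_shooting take (cost, dynamics, init_state, init_controls, control_low, control_high, random_key, hyperparams) in that order. That order is consistent with 7 being the correct index, but it should be checked against optimizers.py.

```python
import inspect


def static_arg_names(fn, static_argnums):
    """Map each static_argnums index to a parameter name; fail on bad indices."""
    params = [
        p.name
        for p in inspect.signature(fn).parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    names = []
    for i in static_argnums:
        # Negative indices count from the end, as with ordinary list indexing.
        if not -len(params) <= i < len(params):
            raise IndexError(
                f"static_argnums index {i} is out of range for "
                f"{len(params)} positional parameters {params}")
        names.append(params[i])
    return names


# Hypothetical stub with the parameter order assumed for cem / random_shooting.
def cem_like(cost, dynamics, init_state, init_controls,
             control_low, control_high, random_key=None, hyperparams=None):
    pass


print(static_arg_names(cem_like, (0, 1, 7)))  # ['cost', 'dynamics', 'hyperparams']

try:
    static_arg_names(cem_like, (0, 1, 9))  # the index used in the issue
except IndexError as err:
    print(err)
```

With index 9 there is no eighth-position hyperparams entry at all, so the mapping that controls the optimizer's settings is not marked static.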
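JAX hashes static arguments to key its compilation cache, so a plain dict, which is unhashable, cannot be passed at a static position. Below is a minimal, dependency-free sketch of a frozen default. FrozenDict and default_hyperparams are hypothetical, and the keys and values are placeholders, not trajax's actual defaults. The frozendict package mentioned in the issue plays the same role.

```python
from collections.abc import Mapping


class FrozenDict(Mapping):
    """Minimal immutable, hashable mapping (hypothetical stand-in for frozendict)."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)
        # Cache the hash up front; this raises TypeError if a value is unhashable.
        self._hash = hash(frozenset(self._data.items()))

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __hash__(self):
        # Mapping supplies __eq__; defining __hash__ makes instances usable
        # wherever hashing is required, such as JAX static arguments.
        return self._hash

    def __repr__(self):
        return f"FrozenDict({self._data!r})"


def default_hyperparams():
    # Placeholder keys and values for illustration only.
    return FrozenDict(max_iter=10, num_samples=400)


plain = dict(max_iter=10, num_samples=400)
try:
    hash(plain)  # a plain dict raises TypeError here
except TypeError as err:
    print("dict is not hashable:", err)

frozen = default_hyperparams()
print(hash(frozen) == hash(default_hyperparams()))  # True
print(frozen == plain)                               # True (Mapping equality)
```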

Would you like a PR for this?

Regards,

Good catch! Feel free to send a PR.

I suspect I'm running into a similar issue with the cem and random_shooting methods. I get this error:

Traceback (most recent call last):
  File "XX/trajax/tests/optimizers_test.py", line 850, in testCEM1
    X_opt, U_opt, obj, = optimizers.random_shooting(
  File "XX/trajax/trajax/optimizers.py", line 1160, in random_shooting
    controls = gaussian_samples(random_key, mean, stdev, control_low,
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 5) of type <class 'dict'> for function gaussian_samples is non-hashable.

when trying to use random_shooting. I will investigate and open a PR for this if I can fix it quickly enough.

I've attached the test method I added to OptimizersTest to demonstrate this behavior:
testCEM1.txt
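The traceback above can be reproduced in isolation with a small jitted function that receives a dict at a static position. This sketch assumes a recent JAX install. sample_controls is a hypothetical stand-in for gaussian_samples and does not match its real signature. To avoid an extra dependency, it uses a sorted tuple of (key, value) pairs as the hashable replacement instead of frozendict.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a sampler such as gaussian_samples. Argument 1 is
# static because its contents decide the output shape at trace time.
@partial(jax.jit, static_argnums=(1,))
def sample_controls(key, hyperparams, mean):
    num_samples = dict(hyperparams)["num_samples"]  # concrete Python int
    noise = jax.random.normal(key, (num_samples,) + mean.shape)
    return mean + noise


key = jax.random.PRNGKey(0)
mean = jnp.zeros(2)

# A plain dict at a static position cannot be hashed, so JAX rejects the call.
try:
    sample_controls(key, {"num_samples": 4}, mean)
except (ValueError, TypeError) as err:  # the traceback above shows ValueError
    print("rejected:", type(err).__name__)

# A hashable equivalent: a sorted tuple of (key, value) pairs.
frozen = tuple(sorted({"num_samples": 4}.items()))
print(sample_controls(key, frozen, mean).shape)  # (4, 2)
```

Because JAX compares static arguments by hash and equality, calling again with an equal tuple reuses the compiled function instead of retracing it.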