google-deepmind / mctx

Monte Carlo tree search in JAX


Incompatible types error with increased global precision

AbhiDu96 opened this issue

Hello,
Thank you for this amazing library. I have been using it in our project for a while, and I have started seeing the following warning:
```
/home/dubey/anaconda3/envs/jaxenv_qiskit/lib/python3.11/site-packages/jax/_src/ops/scatter.py:93: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
```
I need increased precision for my experiments, so I enabled 64-bit mode globally:
```
jax.config.update("jax_enable_x64", True)
```
I have tracked the problem down to the `gumbel_muzero_policy` function and the code it calls:

  1. `action_selection` module: line 140 explicitly sets the dtype to `jnp.int32`, which I believe causes the warning above.
  2. `search` module: lines 172, 288, 364, 367 and 375 also explicitly set the dtype to `int32`.
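The dtype mismatch can be reproduced without mctx. This is a minimal sketch with hypothetical stand-ins (`scores` for action scores, `buffer` for an array pinned to `int32`, like the tree arrays listed above): with x64 enabled, `jnp.argmax` returns the default integer dtype, `int64`, and `int64` cannot be safely cast to `int32`, which is the condition the scatter warning reports.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Hypothetical stand-ins: action scores and an index buffer pinned to int32.
scores = jnp.array([0.1, 0.7, 0.2])
buffer = jnp.zeros(4, dtype=jnp.int32)

# With x64 enabled, argmax returns the default integer dtype (int64).
action = jnp.argmax(scores)
print(action.dtype)

# int64 -> int32 is not a safe cast, so buffer.at[0].set(action) would
# emit the "incompatible types" scatter warning quoted above.
print(np.can_cast(action.dtype, buffer.dtype))
```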

I am not sure how to resolve this; please let me know if anything here is unclear.

Update: I tried keeping `int32` everywhere, but as expected that does not help: with x64 enabled globally, the default precision of everything else changes as well.

Thanks.
With `jax_enable_x64`, the `argmax` output was `int64`. I changed the code to use the expected `int32`.
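A minimal sketch of this kind of fix, assuming the index comes from `jnp.argmax`: cast the result to `jnp.int32` before writing it into `int32` arrays. The helper `select_action` and the array `buffer` are hypothetical names for illustration, not the actual mctx patch.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)


def select_action(scores):
    """Hypothetical helper: argmax cast to the int32 index dtype."""
    # Without the cast, argmax yields int64 when x64 is enabled.
    return jnp.argmax(scores, axis=-1).astype(jnp.int32)


buffer = jnp.zeros(4, dtype=jnp.int32)
action = select_action(jnp.array([0.1, 0.7, 0.2]))
# Both sides are int32, so the scatter needs no cast.
buffer = buffer.at[0].set(action)
print(action.dtype, buffer)
```

Casting at the point where the index is produced keeps the tree arrays at `int32` regardless of the global x64 setting.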

The implementation has not been tested with `jax_enable_x64`, so check that your results are sensible.

Hello @fidlej,
Thanks for the fix. With the corresponding changes it runs without any warnings or errors. I still have to check whether the results with `jax_enable_x64` are sensible, but I think we can close the issue.
Thanks again.