AF2_metrics jax error
JonathanEAsh opened this issue
Hello Jue,
I am trying to run af2_metrics.py on my relaxed models, but I'm getting this error:
File "/projects/f_sdk94_1/PDZ_complex/RFDesign_env/SE3-nvidia/lib/python3.9/site-packages/haiku/_src/transform.py", line 264, in check_not_jax_transformed
if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)): # pytype: disable=name-error
AttributeError: module 'jaxlib.xla_extension' has no attribute 'CompiledFunction'
I built the conda environment from the yml file and then downgraded jax and jaxlib to the versions specified in an earlier issue:
jax 0.2.17
jaxlib 0.1.69
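It can help to report exactly which jax, jaxlib, and haiku versions the environment actually resolves, in case they differ from what was requested. The following is a minimal sketch, assuming it is run with the environment's own Python interpreter. `installed_versions` is a hypothetical helper, and `dm-haiku` is the PyPI distribution name for haiku.

```python
from importlib import metadata

# Distributions relevant to the errors above ("dm-haiku" is haiku's PyPI name).
PACKAGES = ("jax", "jaxlib", "dm-haiku")


def installed_versions(names=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:10s} {version or 'not installed'}")
```

Pasting this output into the issue gives maintainers the versions Python really imports, rather than the ones listed in the yml file.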
With the jax and jaxlib versions specified in the yml file, af2_metrics.py fails with a similar AttributeError, this time saying that jax.ops has no attribute 'index_add'. Any help solving this issue would be greatly appreciated.
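For reference, newer JAX releases removed `jax.ops.index_add` in favor of the indexed-update syntax `x.at[idx].add(y)`. The sketch below uses hypothetical arrays `x`, `idx`, and `y` to show the equivalent call. It assumes a JAX version that supports `.at`. It is not a confirmed fix for af2_metrics.py, because other dependencies such as haiku may also need matching versions.

```python
import jax.numpy as jnp

# Hypothetical inputs: add y into x at positions idx (index 2 appears twice).
x = jnp.zeros(5)
idx = jnp.array([0, 2, 2])
y = jnp.array([1.0, 2.0, 3.0])

# Old, removed form: jax.ops.index_add(x, idx, y)
# New form: returns a new array; x itself is unchanged (JAX arrays are immutable).
out = x.at[idx].add(y)  # duplicate indices accumulate: out[2] == 2.0 + 3.0

print(out)
```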
Thanks!
Jonathan