RosettaCommons / RFDesign

Protein hallucination and inpainting with RoseTTAFold



AF2_metrics jax error

JonathanEAsh opened this issue

Hello Jue,

I am trying to run af2_metrics.py on my relaxed models, but I'm getting this error:

File "/projects/f_sdk94_1/PDZ_complex/RFDesign_env/SE3-nvidia/lib/python3.9/site-packages/haiku/_src/transform.py", line 264, in check_not_jax_transformed
if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)): # pytype: disable=name-error
AttributeError: module 'jaxlib.xla_extension' has no attribute 'CompiledFunction'

I built the conda environment from the yml file, then downgraded jax and jaxlib to the versions recommended in an earlier issue:
jax 0.2.17
jaxlib 0.1.69

Running af2_metrics.py with the jax and jaxlib versions pinned in the yml file produced a similar AttributeError, this time because jax.ops has no attribute 'index_add'. Any help resolving this would be greatly appreciated.
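Before changing versions again, it may help to confirm which jax, jaxlib, and dm-haiku builds the environment actually resolves, since another copy elsewhere on the path can shadow a downgrade. Below is a minimal sketch using only the standard library. The helper names are hypothetical, and the script only reports what is installed and whether the two attributes from the errors above exist. It does not fix anything.

```python
import importlib
import importlib.metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def has_attribute(module_name, attr):
    """Return True/False if the module imports, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except Exception:
        # An import can fail for reasons other than ImportError (for example a
        # version check), so any failure is treated as "not importable".
        return None
    return hasattr(module, attr)


if __name__ == "__main__":
    # PyPI distribution names; "dm-haiku" is the package that provides `haiku`.
    for name, version in installed_versions(["jax", "jaxlib", "dm-haiku"]).items():
        print(f"{name}: {version or 'not installed'}")
    # The two attributes named in the AttributeErrors reported above.
    print("jaxlib.xla_extension.CompiledFunction:",
          has_attribute("jaxlib.xla_extension", "CompiledFunction"))
    print("jax.ops.index_add:", has_attribute("jax.ops", "index_add"))
```

Run it with the same Python interpreter that runs af2_metrics.py, so it sees the same site-packages.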

Thanks!
Jonathan