jannerm / ddpo

Code for the paper "Training Diffusion Models with Reinforcement Learning"

Home Page: https://rl-diffusion.github.io

Error on RTX 4090 + CUDA 12.1

chaojiewang94 opened this issue

WARNING:jax.experimental.compilation_cache.compilation_cache:Initialized persistent compilation cache at cache
[ utils/logger ] Suppressing most dependency logging
2023-06-11 16:58:17.065316: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:417] Loaded runtime CuDNN library: 8.3.2 but source was compiled with: 8.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
File "/home/chaojiewang/Desktop/Logical_Diffusion/Diffusion_RL/pipeline/policy_gradient.py", line 484, in <module>
main()
File "/home/chaojiewang/Desktop/Logical_Diffusion/Diffusion_RL/pipeline/policy_gradient.py", line 51, in main
rng = jax.random.PRNGKey(args.seed)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/random.py", line 136, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 561, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
return seed(seeds)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/prng.py", line 813, in threefry_seed
lax.shift_right_logical(seed, lax_internal._const(seed, 32)))
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 458, in shift_right_logical
return shift_right_logical_p.bind(x, y)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
return f(*args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 519, in compile_or_get_cached
compiled = backend_compile(backend, serialized_computation,
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/chaojiewang/anaconda3/envs/ddpo-gpu/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
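The first error line in the log states the rule XLA applies when it loads cuDNN: the runtime library must have the same major version as, and an equal or higher minor version than, the one jaxlib was compiled against. Below is a minimal sketch of that check, assuming plain "major.minor.patch" version strings; `cudnn_compatible` is a hypothetical helper for illustration, not part of JAX or ddpo.

```python
from __future__ import annotations


def parse_version(text: str) -> tuple[int, ...]:
    """Split a dotted version string such as '8.3.2' into integers."""
    return tuple(int(part) for part in text.split("."))


def cudnn_compatible(runtime: str, compiled: str) -> bool:
    """Hypothetical helper mirroring the rule printed by cuda_dnn.cc:
    same major version, and runtime minor >= compiled minor.
    The patch component is ignored."""
    rt_major, rt_minor = parse_version(runtime)[:2]
    cp_major, cp_minor = parse_version(compiled)[:2]
    return rt_major == cp_major and rt_minor >= cp_minor


if __name__ == "__main__":
    # Versions taken from the log: runtime 8.3.2, compiled against 8.8.0.
    print(cudnn_compatible("8.3.2", "8.8.0"))
    # A newer 8.x runtime would satisfy the rule.
    print(cudnn_compatible("8.9.0", "8.8.0"))
```

By this rule, a cuDNN 8.3.2 runtime cannot serve a jaxlib built against 8.8.0, which is consistent with the failure surfacing at the very first compiled JAX operation (`jax.random.PRNGKey`) rather than in any model code.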

It looks like you are using the GPU dependencies from #3, right? While I'm super glad that this exists, unfortunately we (the authors) don't have the GPUs to be able to test it, so we won't be able to provide support for cuDNN issues.

@kvablack @jannerm Yes, the dependencies I used are consistent with those in #3, but I still run into the issue above. I wonder whether anyone else will release a GPU version of the code so that we can follow up on your work?

@kvablack @jannerm I think it is caused by the PyTorch version: even the latest PyTorch cannot support cuDNN 8.8 or higher. So I am quite curious how you reproduced this work on GPU.
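Since the log shows the loaded cuDNN (8.3.2) differs from the one jaxlib expects (8.8.0), it can help to see what each framework reports inside the same environment. This is a sketch, not part of ddpo: `gpu_stack_report` and `installed_version` are hypothetical helpers, and the only framework calls used are `torch.version.cuda`, `torch.backends.cudnn.version()` and `jax.default_backend()`. Each framework is imported only if it is available.

```python
from __future__ import annotations

import importlib.metadata
from typing import Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def gpu_stack_report() -> dict:
    """Collect what PyTorch and JAX report about their CUDA/cuDNN setup.

    Values that cannot be determined (framework missing, CPU-only build)
    are left as None.
    """
    report = {
        "torch": installed_version("torch"),
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
        "torch_cuda": None,
        "torch_cudnn": None,
        "jax_backend": None,
    }
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None:
        # CUDA version PyTorch was built with; None on CPU-only builds.
        report["torch_cuda"] = torch.version.cuda
        # Integer-encoded cuDNN version, or None if cuDNN is unavailable.
        report["torch_cudnn"] = torch.backends.cudnn.version()
    try:
        import jax
    except ImportError:
        jax = None
    if jax is not None:
        # Platform JAX will run on by default, e.g. "gpu" or "cpu".
        report["jax_backend"] = jax.default_backend()
    return report


if __name__ == "__main__":
    for key, value in gpu_stack_report().items():
        print(f"{key}: {value}")
```

If `jax_backend` comes back as `cpu` on a GPU machine, or the two frameworks disagree about the CUDA major version, that points at the environment rather than at the ddpo code.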

Hi,

I am able to train it on an A100-80GB after doing the following modifications to the yml file:

```yaml
name: ddpo-gpu
channels:
  - defaults
  - conda-forge
  - pytorch
dependencies:
  - python=3.9
  - pip
  - pip:
    - -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    - --extra-index-url https://download.pytorch.org/whl/cu118
    - imageio == 2.22.4
    - imageio-ffmpeg >= 0.4.3
    - scikit-video == 1.1.11
    - torch
    - torchvision
    - diffusers==0.12.1
    - transformers == 4.28.1
    - numpy >= 1.20.2
    - tqdm >= 4.60.0
    - flax == 0.6.9
    - optax == 0.1.5
    - jax[cuda11_pip] == 0.4.8
    - jaxlib == 0.4.7
    - optax >= 0.0.6
    - typed-argument-parser == 1.7.2
    - gitpython == 3.1.29
    - google-cloud-storage == 2.6.0
    - gcsfs == 2022.11.0
    - fsspec == 2022.11.0
    - h5py == 3.7.0
    - datasets == 2.7.1
    - matplotlib == 3.6.2
    - latex == 0.7.0
    - inflect == 6.0.4
    - gsutil
```

In short, PyTorch needs to be updated to the newest version built for CUDA 11.8 (hence the cu118 extra index URL), and JAX needs to switch from the cuda12_pip extra to cuda11_pip. Also, an RTX 4090 (24 GB) might not have enough GPU memory; I trained on an 80 GB A100. Finally, I needed to add gsutil.

PyTorch + LoRA coming soon! See my comment on #2.

Update: You may have to update the inflect package to 6.2.0 or later!

Hey everyone, I'm going to close this issue, as I would recommend moving over to the PyTorch version, which is designed for GPUs.