araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

[Bug] TQC Hyperparameter optimization: Results do not match the reference. This is likely a bug/unexpected loss of precision.

edmund735 opened this issue

πŸ› Bug

Hi,

When I run TQC hyperparameter optimization with multiple jobs (`--n-jobs` > 1) on a GPU, I get the following error (this also happens with multiple CPU cores and `--n-jobs 1`):

```
2024-04-07 14:35:59.992779: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 0: -inf, expected -0.000287323
2024-04-07 14:35:59.992804: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 1: -inf, expected -0.000267224
2024-04-07 14:35:59.992808: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 2: -inf, expected -0.000226477
2024-04-07 14:35:59.992811: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 3: -inf, expected -0.000281823
2024-04-07 14:35:59.992813: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 4: -inf, expected -0.000262532
2024-04-07 14:35:59.992815: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 5: -inf, expected -0.000252724
2024-04-07 14:35:59.992818: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 6: -inf, expected -0.000250007
2024-04-07 14:35:59.992820: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 7: -inf, expected -0.000265674
2024-04-07 14:35:59.992823: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 8: -inf, expected -0.00021464
2024-04-07 14:35:59.992825: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 9: -inf, expected -0.000204733
E0407 14:35:59.992828  798907 triton_autotuner.cc:766] Results do not match the reference. This is likely a bug/unexpected loss of precision.
```

### To Reproduce

```bash
python rl-baselines3-zoo/train_sbx.py --algo tqc --env Pendulum-v1 -n 5000 --n-trials 50 --num-threads 1 --n-jobs 4 --log-interval 4900 --eval-episodes 16 --n-eval-envs 8 --seed 8 --vec-env "dummy" -optimize --sampler tpe --pruner median --n-startup-trials 10
```

```
[W 2024-04-07 14:36:00,208] Trial 16 failed with parameters: {'gamma': 0.995, 'learning_rate': 0.23149128592335125, 'batch_size': 1024, 'buffer_size': 10000, 'learning_starts': 1000, 'train_freq': 16, 'tau': 0.08, 'log_std_init': -0.3684256821552643, 'net_arch': 'medium', 'n_quantiles': 32, 'top_quantiles_to_drop_per_net': 30} because of the following error: XlaRuntimeError('INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.  Per-algorithm errors:\n  Results do not match the reference. This is likely a bug/unexpected loss of precision.

Traceback (most recent call last):
  File "/home/.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)  # type: ignore[arg-type]
  File "/home/.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.conda/envs/.../lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.
```

```
Traceback (most recent call last):
  File "/scratch/network/.../.../rl-baselines3-zoo/train_sbx.py", line 19, in <module>
    train()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/train.py", line 275, in train
    exp_manager.hyperparameters_optimization()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 874, in hyperparameters_optimization
    study.optimize(self.objective, n_jobs=self.n_jobs, n_trials=self.n_trials)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/study.py", line 451, in optimize
    _optimize(
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 99, in _optimize
    f.result()
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 159, in _optimize_sequential
    frozen_trial = _run_trial(study, func, catch)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 247, in _run_trial
    raise func_err
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs) # type: ignore[arg-type]
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.
```

### System Info

Describe the characteristics of your environment:

  • Library installed through pip

  • GPU models and configuration
    +---------------------------------------------------------------------------------------+
    | NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
    |-----------------------------------------+----------------------+----------------------+
    | GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
    |                                         |                      |               MIG M. |
    |=========================================+======================+======================|
    |   0  NVIDIA A100 80GB PCIe          On  | 00000000:0D:00.0 Off |                    0 |
    | N/A   40C    P0              67W / 300W |   3508MiB / 81920MiB |      0%      Default |
    |                                         |                      |             Disabled |
    +-----------------------------------------+----------------------+----------------------+
    |   1  NVIDIA A100 80GB PCIe          On  | 00000000:B5:00.0 Off |                    0 |
    | N/A   38C    P0              49W / 300W |      5MiB / 81920MiB |      0%      Default |
    |                                         |                      |             Disabled |
    +-----------------------------------------+----------------------+----------------------+

  • Python 3.10.14

  • pytorch 2.2.2 py3.10_cuda12.1_cudnn8.9.2_0
    pytorch-cuda 12.1 ha16c6d3_5 pytorch
    pytorch-mutex 1.0 cuda pytorch
    torchtriton 2.2.0 py310 pytorch

  • Gym version
    gymnasium 0.29.1

  • Versions of any other relevant libraries
    jax 0.4.25 pyhd8ed1ab_0 conda-forge
    jax-jumpy 1.0.0 pyhd8ed1ab_0 conda-forge
    jaxlib 0.4.23 cuda118py310h8c47008_200 conda-forge

### Additional context

I've noticed there's no bug with `--n-jobs 1`, only when running multiple jobs. Could it be related to the way Optuna runs multiple jobs?
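For context on that guess: the traceback goes through `optuna/study/_optimize.py` into `concurrent/futures/thread.py`, i.e. with `n_jobs` > 1 Optuna runs trials as threads of a single Python process, so all of them drive the same JAX runtime. The stdlib-only sketch below roughly mimics that dispatch pattern to make this visible; `fake_objective` and `run_like_n_jobs` are hypothetical stand-ins, not Optuna's or the zoo's code.

```python
import os
import threading
from concurrent.futures import ThreadPoolExecutor


def fake_objective(trial_number: int) -> tuple[int, int, int]:
    # Hypothetical stand-in for the zoo's objective: report which process
    # and which thread evaluated this "trial".
    return trial_number, os.getpid(), threading.get_ident()


def run_like_n_jobs(n_trials: int, n_jobs: int) -> list[tuple[int, int, int]]:
    # Trials are handed to a thread pool inside the *same* process, as in the
    # optuna/_optimize.py -> concurrent/futures/thread.py frames above.
    with ThreadPoolExecutor(max_workers=n_jobs) as pool:
        return list(pool.map(fake_objective, range(n_trials)))


if __name__ == "__main__":
    results = run_like_n_jobs(n_trials=8, n_jobs=4)
    pids = {pid for _, pid, _ in results}
    threads = {tid for _, _, tid in results}
    print(f"processes used: {len(pids)}, threads used: {len(threads)}")
```

Every trial reports the same process ID; only the thread IDs differ. That is why `--n-jobs` > 1 means concurrent JAX/XLA work (compilation, autotuning, kernel launches) inside one process.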

### Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

I ran this again with a new jaxlib build (jaxlib 0.4.23 cuda120py310h3cc97ca_20) and it now gives a different error:
```
[I 2024-04-07 15:55:02,691] A new study created in memory with name: no-name-0da03417-c265-43a2-a55b-10d9750abcca
2024-04-07 15:55:23.576297: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Failed to launch CUDA kernel: triton_gemm_dot_74; block dims: 256x1x1; grid dims: 32x1x1; shared memory size: 24576: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture
2024-04-07 15:55:23.576337: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CaptureGpuGraph failed (Failed to launch CUDA kernel: triton_gemm_dot_74; block dims: 256x1x1; grid dims: 32x1x1; shared memory size: 24576: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current tracing scope: triton_gemm_dot.86): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture
2024-04-07 15:55:23.576394: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: CaptureGpuGraph failed (Failed to launch CUDA kernel: triton_gemm_dot_74; block dims: 256x1x1; grid dims: 32x1x1; shared memory size: 24576: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current tracing scope: triton_gemm_dot.86): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current profiling annotation: XlaModule:#prefix=jit(_train)/jit(main)/while/body,hlo_module=jit__train,program_id=116#.
INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: CaptureGpuGraph failed (Failed to launch CUDA kernel: triton_gemm_dot_74; block dims: 256x1x1; grid dims: 32x1x1; shared memory size: 24576: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current tracing scope: triton_gemm_dot.86): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current profiling annotation: XlaModule:#prefix=jit(_train)/jit(main)/while/body,hlo_module=jit__train,program_id=116#.

Sampled hyperparams:
{'batch_size': 1024,
 'buffer_size': 100000,
 'ent_coef': 'auto',
 'gamma': 0.9999,
 'gradient_steps': 1,
 'learning_rate': 0.004315216575412321,
 'learning_starts': 0,
 'policy_kwargs': {'log_std_init': -1.4239746627852474,
                   'n_quantiles': 31,
                   'net_arch': [64, 64],
                   'top_quantiles_to_drop_per_net': 25,
                   'use_sde': False},
 'target_entropy': 'auto',
 'tau': 0.02,
 'top_quantiles_to_drop_per_net': 25,
 'train_freq': 1}
2024-04-07 15:55:23.577027: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1883] could not synchronize on CUDA context: CUDA_ERROR_STREAM_CAPTURE_UNSUPPORTED: operation not permitted when stream is capturing :: *** Begin stack trace ***
_PyObject_MakeTpCall

_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall

_PyObject_MakeTpCall
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall


_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall

_PyEval_EvalFrameDefault

_PyEval_EvalFrameDefault

_PyEval_EvalFrameDefault

PyObject_Call
_PyEval_EvalFrameDefault

_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
clone

*** End stack trace ***

[I 2024-04-07 15:55:23,577] Trial 1 pruned.
[W 2024-04-07 15:55:23,606] Trial 3 failed with parameters: {'gamma': 1, 'learning_rate': 0.03739146141228411, 'batch_size': 256, 'buffer_size': 100000, 'learning_starts': 1000, 'train_freq': 1, 'tau': 0.02, 'log_std_init': -1.1735175685607313, 'net_arch': 'small', 'n_quantiles': 26, 'top_quantiles_to_drop_per_net': 24} because of the following error: XlaRuntimeError('INTERNAL: Failed to synchronize GPU for autotuning.').
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs) # type: ignore[arg-type]
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to synchronize GPU for autotuning.

[W 2024-04-07 15:55:23,607] Trial 3 failed with value None.
[I 2024-04-07 15:55:44,276] Trial 2 finished with value: -149.14224087500003 and parameters: {'gamma': 0.995, 'learning_rate': 0.005830150992686316, 'batch_size': 2048, 'buffer_size': 10000, 'learning_starts': 0, 'train_freq': 8, 'tau': 0.01, 'log_std_init': -3.101106181907312, 'net_arch': 'medium', 'n_quantiles': 13, 'top_quantiles_to_drop_per_net': 1}. Best is trial 2 with value: -149.14224087500003.
[I 2024-04-07 15:55:44,442] Trial 0 finished with value: -1286.9111508125 and parameters: {'gamma': 0.995, 'learning_rate': 0.03380452664776398, 'batch_size': 128, 'buffer_size': 1000000, 'learning_starts': 1000, 'train_freq': 4, 'tau': 0.02, 'log_std_init': -3.2686941182290763, 'net_arch': 'big', 'n_quantiles': 45, 'top_quantiles_to_drop_per_net': 13}. Best is trial 2 with value: -149.14224087500003.
[I 2024-04-07 15:55:44,610] Trial 4 finished with value: -408.14151849999996 and parameters: {'gamma': 0.99, 'learning_rate': 0.022024554072114278, 'batch_size': 512, 'buffer_size': 100000, 'learning_starts': 1000, 'train_freq': 8, 'tau': 0.02, 'log_std_init': 0.9307981026739451, 'net_arch': 'big', 'n_quantiles': 35, 'top_quantiles_to_drop_per_net': 29}. Best is trial 2 with value: -149.14224087500003.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/scratch/network/.../.../rl-baselines3-zoo/train_sbx.py", line 19, in <module>
    train()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/train.py", line 275, in train
    exp_manager.hyperparameters_optimization()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 874, in hyperparameters_optimization
    study.optimize(self.objective, n_jobs=self.n_jobs, n_trials=self.n_trials)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/study.py", line 451, in optimize
    _optimize(
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 99, in _optimize
    f.result()
  File "/home/.../.conda/envs/...1/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/.../.conda/envs/...1/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/.../.conda/envs/...1/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 159, in _optimize_sequential
    frozen_trial = _run_trial(study, func, catch)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 247, in _run_trial
    raise func_err
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs) # type: ignore[arg-type]
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to synchronize GPU for autotuning.
```

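This might be related to JAX not handling multi-threading/multi-processing well: with `--n-jobs` > 1, Optuna runs trials in a thread pool inside a single process (note `concurrent/futures/thread.py` in the traceback), so several trials drive the same JAX/XLA GPU runtime at the same time.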
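Every error above is raised from XLA's GPU autotuning (`triton_autotuner.cc`, "Failed to synchronize GPU for autotuning"). An untested idea, if you want to keep `--n-jobs` > 1, is to switch that autotuning off through `XLA_FLAGS`. The sketch below only builds the environment variable; the helper `add_xla_flags` is hypothetical, and the two flag names are XLA debug options I believe exist in jaxlib 0.4.x, so treat them as assumptions and check that your build accepts them (XLA may refuse to start on unknown flags).

```python
import os
from collections.abc import MutableMapping

# Assumed XLA debug flags (verify against your jaxlib version):
EXTRA_XLA_FLAGS = [
    "--xla_gpu_autotune_level=0",          # 0 = no GEMM/convolution autotuning
    "--xla_gpu_enable_triton_gemm=false",  # do not generate Triton GEMM kernels
]


def add_xla_flags(
    env: MutableMapping[str, str], flags: list[str]
) -> MutableMapping[str, str]:
    # Append flags to any existing XLA_FLAGS, skipping ones already present.
    existing = env.get("XLA_FLAGS", "").split()
    merged = existing + [f for f in flags if f not in existing]
    env["XLA_FLAGS"] = " ".join(merged)
    return env


if __name__ == "__main__":
    # XLA reads XLA_FLAGS when JAX initialises its backend, so set it before
    # `import jax` (i.e. before sbx or rl_zoo3 are imported).
    add_xla_flags(os.environ, EXTRA_XLA_FLAGS)
    print(os.environ["XLA_FLAGS"])
```

Disabling autotuning makes XLA fall back to its default GEMM choices, which may be slower; it would not remove the underlying thread contention, only the autotuner step that fails under it.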

You should probably have a look at distributed tuning using a shared database (I would recommend the log format): https://rl-baselines3-zoo.readthedocs.io/en/master/guide/tuning.html
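Following that suggestion, here is a sketch of a launcher that replaces `--n-jobs 4` with several independent worker processes, each running with `--n-jobs 1` and sharing one study through a `.log` file. Assumptions: the zoo's `--storage` and `--study-name` options as described in the tuning guide above, that a `.log` storage path selects Optuna's journal file storage, and that the sampler is seeded from `--seed`. The names `build_worker_cmd`, `worker_env`, `launch`, the study name and the storage path are hypothetical.

```python
import os
import subprocess
import sys

ZOO_SCRIPT = "rl-baselines3-zoo/train_sbx.py"  # path used in the report


def build_worker_cmd(worker_id: int, storage: str, study_name: str,
                     n_trials: int, base_seed: int = 8) -> list[str]:
    # One worker = one independent process with --n-jobs 1; workers only
    # coordinate through the shared storage file.
    return [
        sys.executable, ZOO_SCRIPT,
        "--algo", "tqc", "--env", "Pendulum-v1", "-n", "5000",
        "-optimize", "--sampler", "tpe", "--pruner", "median",
        "--num-threads", "1", "--n-jobs", "1",
        "--n-trials", str(n_trials),
        "--storage", storage,
        "--study-name", study_name,
        # Different seed per worker so their random start-up trials differ
        # (assumes the sampler is seeded from --seed).
        "--seed", str(base_seed + worker_id),
    ]


def worker_env(worker_id: int, n_gpus: int) -> dict[str, str]:
    env = dict(os.environ)
    # Spread workers over the GPUs (the report lists two A100s).
    env["CUDA_VISIBLE_DEVICES"] = str(worker_id % n_gpus)
    # Stop each JAX process from preallocating most of the GPU memory,
    # so several processes can share one device.
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    return env


def launch(n_workers: int, total_trials: int, n_gpus: int = 2) -> list[int]:
    storage = "logs/tqc_pendulum.log"  # hypothetical path (.log -> journal)
    os.makedirs(os.path.dirname(storage), exist_ok=True)
    trials_per_worker = -(-total_trials // n_workers)  # ceiling division
    procs = [
        subprocess.Popen(
            build_worker_cmd(i, storage, "tqc_pendulum", trials_per_worker),
            env=worker_env(i, n_gpus),
        )
        for i in range(n_workers)
    ]
    return [p.wait() for p in procs]  # exit codes of the workers


if __name__ == "__main__":
    if os.path.exists(ZOO_SCRIPT):
        print(launch(n_workers=4, total_trials=50))
    else:
        print(f"{ZOO_SCRIPT} not found; run from the parent directory of the zoo")
```

Since `study.optimize` is called with `n_trials` per process (`exp_manager.py`, line 874 in the traceback), the budget is split across workers: 4 workers × 13 trials = 52 ≥ 50. Separate processes also get separate CUDA contexts, so trials no longer share one XLA runtime the way Optuna's threads do.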