SAC-Discrete in PyTorch

This is a PyTorch implementation of SAC-Discrete[1]. I tried to make it easy for readers to understand the algorithm. Please let me know if you have any questions.
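
For readers who want the gist before diving into the code, below is a minimal, illustrative sketch of the discrete-action soft actor-critic losses described in [1]. The function and variable names (policy, q1, q2, alpha, etc.) are assumptions made for this sketch and do not necessarily match the classes in this repository.

import torch
import torch.nn.functional as F

def sacd_losses(policy, q1, q2, q1_target, q2_target, batch, alpha, gamma=0.99):
    """Illustrative SAC-Discrete losses (Christodoulou, 2019).

    Assumes policy(s) returns action probabilities of shape (batch, n_actions)
    and each Q network returns Q-values of the same shape.
    """
    states, actions, rewards, next_states, dones = batch

    # Critic target: for discrete actions the expectation over a' is exact.
    with torch.no_grad():
        next_probs = policy(next_states)                      # pi(a'|s')
        next_log_probs = torch.log(next_probs + 1e-8)
        next_q = torch.min(q1_target(next_states), q2_target(next_states))
        # V(s') = E_{a'}[ Q(s',a') - alpha * log pi(a'|s') ]
        next_v = (next_probs * (next_q - alpha * next_log_probs)).sum(dim=1)
        target_q = rewards + (1.0 - dones) * gamma * next_v

    # Critic losses: regress Q(s, a_taken) toward the target.
    q1_pred = q1(states).gather(1, actions.long().unsqueeze(1)).squeeze(1)
    q2_pred = q2(states).gather(1, actions.long().unsqueeze(1)).squeeze(1)
    q1_loss = F.mse_loss(q1_pred, target_q)
    q2_loss = F.mse_loss(q2_pred, target_q)

    # Policy loss: minimize E_a[ alpha * log pi(a|s) - Q(s,a) ].
    probs = policy(states)
    log_probs = torch.log(probs + 1e-8)
    with torch.no_grad():
        min_q = torch.min(q1(states), q2(states))
    policy_loss = (probs * (alpha * log_probs - min_q)).sum(dim=1).mean()

    return policy_loss, q1_loss, q2_loss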

UPDATE

  • 2020.5.10
    • Refactored the code and fixed a bug in the SAC-Discrete algorithm.
    • Implemented Prioritized Experience Replay[2], N-step returns, and Dueling Networks[3] (see the dueling-head sketch after this list).
    • Tested them.
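
As a quick reference for the dueling architecture [3], here is a minimal sketch of a dueling Q-head. The class name, layer sizes, and feature dimension are illustrative and not taken from this repository's code.

import torch.nn as nn

class DuelingQHead(nn.Module):
    """Illustrative dueling head: Q(s,a) = V(s) + A(s,a) - mean_a A(s,a)."""

    def __init__(self, feature_dim, n_actions, hidden_dim=512):
        super().__init__()
        self.value_stream = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1))
        self.advantage_stream = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, n_actions))

    def forward(self, features):
        value = self.value_stream(features)               # (batch, 1)
        advantage = self.advantage_stream(features)       # (batch, n_actions)
        # Subtract the mean advantage so V and A are identifiable.
        return value + advantage - advantage.mean(dim=1, keepdim=True)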

Setup

If you are using Anaconda, first create a virtual environment.

conda create -n sacd python=3.7 -y
conda activate sacd

You can install the Python libraries using pip.

pip install -r requirements.txt

If you're using a CUDA version other than 10.2, you may need to install the PyTorch build that matches your CUDA version. See the instructions for more details.

Examples

You can train a SAC-Discrete agent as in the example below.

python train.py --config config/sacd.yaml --env_id MsPacmanNoFrameskip-v4 --cuda --seed 0

If you want to use Prioritized Experience Replay (PER), N-step returns, or Dueling Networks, change use_per, multi_step, or dueling_net, respectively.
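
To clarify what the multi_step option computes conceptually, here is a minimal sketch of an N-step return target. The function and argument names are illustrative and do not reflect this repository's replay buffer API.

def n_step_return(rewards, gamma, next_value, done):
    """Illustrative N-step target:
    R = r_0 + gamma*r_1 + ... + gamma^(n-1)*r_{n-1} + gamma^n * V(s_n) * (1 - done),
    where `rewards` holds the n most recent rewards, oldest first.
    """
    ret = 0.0
    for i, r in enumerate(rewards):
        ret += (gamma ** i) * r
    ret += (gamma ** len(rewards)) * next_value * (1.0 - float(done))
    return ret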

Results

I evaluated vanilla SAC-Discrete, as well as variants with PER, N-step returns, and Dueling Networks, on MsPacmanNoFrameskip-v4. The graph below shows test returns against environment steps (which equal environment frames divided by the frame-skip factor of 4). Note that the curves are smoothed with an exponential moving average (weight = 0.5) for visualization.

N-step returns and PER seem helpful for better utilizing RL signals (e.g. sparse rewards).
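
For intuition on why PER can help, here is a minimal sketch of proportional prioritized sampling as described in [2]. Real implementations (presumably including the one in this repository) typically use a sum tree for efficiency; the function below is only illustrative.

import numpy as np

def sample_prioritized(priorities, batch_size, alpha=0.6, beta=0.4):
    """Illustrative proportional PER sampling (Schaul et al., 2015).

    Transitions with larger priority (e.g. TD error) are sampled more often;
    importance-sampling weights correct the bias this introduces.
    """
    priorities = np.asarray(priorities, dtype=np.float64)
    probs = priorities ** alpha
    probs /= probs.sum()

    indices = np.random.choice(len(priorities), batch_size, p=probs)
    weights = (len(priorities) * probs[indices]) ** (-beta)
    weights /= weights.max()  # normalize so the largest weight is 1
    return indices, weights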

References

[1] Christodoulou, Petros. "Soft Actor-Critic for Discrete Action Settings." arXiv preprint arXiv:1910.07207 (2019).

[2] Schaul, Tom, et al. "Prioritized experience replay." arXiv preprint arXiv:1511.05952 (2015).

[3] Wang, Ziyu, et al. "Dueling network architectures for deep reinforcement learning." arXiv preprint arXiv:1511.06581 (2015).

License

MIT License