alirezakazemipour / TD3-PyTorch

Addressing Function Approximation Error in Actor-Critic Methods



TD3-PyTorch

TD3 (Twin Delayed Deep Deterministic Policy Gradient) is roughly the continuous-control counterpart of Double Q-Learning. To counter the overestimation bias of DDPG, it borrows the key idea of Double Q-Learning, decoupling action selection from action evaluation in the Bellman target, and implements it with two separate critic networks: the target value is taken from the smaller of their two estimates (clipped double Q-learning).
TD3 also adds further tricks, namely delayed policy updates and Target Policy Smoothing, to reduce the high variance common in policy-gradient methods.
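The two critics and Target Policy Smoothing both show up in the critic's regression target. Below is a minimal, framework-free sketch of that target as described in the original TD3 paper; it is not this repository's code. The names `td3_target`, `Actor`, `Critic` and `clip`, the plain-callable networks, and the default hyperparameters (discount 0.99, smoothing noise std 0.2 clipped to ±0.5) are assumptions based on the paper's reported settings; the repository may use different values.

```python
import random
from typing import Callable, List, Optional, Sequence

# Hypothetical signatures: any callables with these shapes work (e.g. wrapped networks).
Actor = Callable[[Sequence[float]], List[float]]
Critic = Callable[[Sequence[float], Sequence[float]], float]


def clip(x: float, low: float, high: float) -> float:
    """Clamp x to the closed interval [low, high]."""
    return max(low, min(high, x))


def td3_target(
    reward: float,
    done: bool,
    next_state: Sequence[float],
    actor_target: Actor,
    critic1_target: Critic,
    critic2_target: Critic,
    gamma: float = 0.99,         # discount factor (assumed default)
    policy_noise: float = 0.2,   # std of the smoothing noise (assumed default)
    noise_clip: float = 0.5,     # noise is clipped to [-noise_clip, noise_clip]
    max_action: float = 1.0,     # actions are clipped to [-max_action, max_action]
    rng: Optional[random.Random] = None,
) -> float:
    """Return the Bellman target y shared by both critics for one transition."""
    if rng is None:
        rng = random.Random()
    # Target policy smoothing: add clipped Gaussian noise to the target action,
    # then keep the result inside the valid action range.
    next_action = [
        clip(a + clip(rng.gauss(0.0, policy_noise), -noise_clip, noise_clip),
             -max_action, max_action)
        for a in actor_target(next_state)
    ]
    # Clipped double Q-learning: use the smaller of the two target critics.
    q_next = min(critic1_target(next_state, next_action),
                 critic2_target(next_state, next_action))
    # Do not bootstrap past the end of an episode.
    return reward + (0.0 if done else gamma * q_next)
```

Both critics are then trained to regress onto this same target. In the paper, the actor and the target networks are updated only once every few critic updates (every 2 steps), which is the "Delayed" part of the name.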

Demo

[Demo animations: Ant, Hopper]

Results

[Training curves: Ant, Hopper. x-axis: episode number.]

Environments tested

  • Pendulum-v0
  • Hopper-v2
  • Ant-v2
  • HalfCheetah-v2

Dependencies

  • gym == 0.17.3
  • mujoco-py == 2.0.2.13
  • numpy == 1.19.2
  • opencv_contrib_python == 4.4.0.44
  • psutil == 5.5.1
  • torch == 1.6.0

Installation

pip3 install -r requirements.txt

Usage

usage: main.py [-h] [--mem_size MEM_SIZE] [--env_name ENV_NAME]
               [--interval INTERVAL] [--do_train]

Variable parameters based on the configuration of the machine or user's choice

optional arguments:
  -h, --help           show this help message and exit
  --mem_size MEM_SIZE  The memory size.
  --env_name ENV_NAME  Name of the environment.
  --interval INTERVAL  The interval specifies how often different parameters
                       should be saved and printed, counted by episodes.
  --do_train           The flag determines whether to train the agent or play
                       with it.
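The help output above is what Python's `argparse` prints for four optional arguments. A minimal sketch of a parser that produces equivalent help, assuming integer types for `--mem_size` and `--interval`; `build_parser` is a hypothetical name, and no defaults are set because the help text does not show the repository's defaults.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser whose --help output mirrors the usage text above."""
    parser = argparse.ArgumentParser(
        prog="main.py",
        description="Variable parameters based on the configuration "
                    "of the machine or user's choice",
    )
    # Memory capacity; type=int is an assumption.
    parser.add_argument("--mem_size", type=int, help="The memory size.")
    parser.add_argument("--env_name", type=str, help="Name of the environment.")
    # Episodes between saving/printing; type=int is an assumption.
    parser.add_argument("--interval", type=int,
                        help="The interval specifies how often different parameters "
                             "should be saved and printed, counted by episodes.")
    # store_true: False unless the flag is given, so omitting it selects play mode.
    parser.add_argument("--do_train", action="store_true",
                        help="The flag determines whether to train the agent "
                             "or play with it.")
    return parser
```

In a script, `args = build_parser().parse_args()` would read these options from the command line. On Python 3.10 and later, argparse labels the section "options:" instead of "optional arguments:".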
  • To train the agent, run the following command (you may change the memory capacity and the environment as you wish):
python3 main.py --mem_size=700000 --env_name="Ant-v2" --do_train
  • Omit the --do_train flag to switch from training mode to testing mode:
python3 main.py --env_name="Ant-v2" # <- Test the agent
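The --mem_size value in the training command sets the agent's memory capacity. In off-policy methods such as DDPG and TD3, that memory is usually an experience replay buffer which, once full, overwrites its oldest transitions. A minimal sketch under that assumption; `ReplayMemory`, `Transition`, `add` and `sample` are hypothetical names, not this repository's API.

```python
import random
from typing import List, NamedTuple, Optional, Sequence


class Transition(NamedTuple):
    state: Sequence[float]
    action: Sequence[float]
    reward: float
    next_state: Sequence[float]
    done: bool


class ReplayMemory:
    """Fixed-capacity ring buffer: once full, the oldest transition is overwritten."""

    def __init__(self, mem_size: int, seed: Optional[int] = None) -> None:
        if mem_size <= 0:
            raise ValueError("mem_size must be positive")
        self.mem_size = mem_size
        self.storage: List[Transition] = []
        self.next_idx = 0  # slot that the next transition will occupy
        self.rng = random.Random(seed)

    def add(self, state, action, reward, next_state, done) -> None:
        t = Transition(state, action, reward, next_state, done)
        if len(self.storage) < self.mem_size:
            self.storage.append(t)
        else:
            self.storage[self.next_idx] = t  # overwrite the oldest entry
        self.next_idx = (self.next_idx + 1) % self.mem_size

    def sample(self, batch_size: int) -> List[Transition]:
        # Uniform sampling without replacement.
        return self.rng.sample(self.storage, batch_size)

    def __len__(self) -> int:
        return len(self.storage)
```

`ReplayMemory(700000)` would correspond to the training command above. Sampling uniformly from a large buffer breaks the temporal correlation between consecutive transitions, which stabilizes the critic updates.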


About


License: GNU General Public License v3.0


Languages

Python 98.8%, Shell 1.2%