vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)

Home Page: http://docs.cleanrl.dev


Add Polyak update to DQN

manjavacas opened this issue

Problem Description

Current Behavior

Currently, the DQN implementation does a hard update of the target network. However, it is also possible to perform soft updates using a soft-update coefficient between 0 and 1 (a Polyak update).

Expected Behavior

Soft updates can increase the stability of learning, as detailed in the original DDPG paper. This is because the target values are constrained to change slowly.

Although this idea came after the original implementation of DQN, it is equally applicable to this algorithm.

Finally, this is also implemented in other reference libraries such as StableBaselines3, although I would understand if it is not added, for the sake of simplicity and adherence to the original DQN implementation.
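
For reference, the soft update rule with a coefficient `polyak_coeff` in (0, 1] is

    θ_target ← polyak_coeff * θ_online + (1 - polyak_coeff) * θ_target

which reduces to the current hard update when `polyak_coeff = 1`.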

Possible Solution

In the current DQN implementation, replace:

    # update the target network
    if global_step % args.target_network_frequency == 0:
        target_network.load_state_dict(q_network.state_dict())

with:

    # soft update the target network (Polyak averaging);
    # polyak_coeff in (0, 1] controls how fast the target tracks the online network
    if global_step % args.target_network_frequency == 0:
        for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
            target_network_param.data.copy_(
                polyak_coeff * q_network_param.data + (1.0 - polyak_coeff) * target_network_param.data
            )
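
To make the coefficient configurable like the other DQN hyperparameters, it could be exposed as a command-line flag. The flag name and default below are only a suggestion (not an existing CleanRL argument); a default of 1.0 keeps the current hard-update behavior:

    import argparse

    parser = argparse.ArgumentParser()
    # hypothetical flag; polyak_coeff = 1.0 reproduces the current hard update exactly
    parser.add_argument("--polyak-coeff", type=float, default=1.0,
        help="the soft update coefficient for the target network (1.0 = hard update)")
    args = parser.parse_args()

In the snippet above, `polyak_coeff` would then simply be `args.polyak_coeff`.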

Hi, thanks for raising this issue. This sounds like a good idea, especially since we are already doing Polyak updates in

    q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

(see the optax docs on optax.incremental_update).
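
For the JAX version the change would presumably be even smaller, since `optax.incremental_update(new, old, step_size)` already computes `step_size * new + (1 - step_size) * old`. A minimal sketch, assuming a `tau` hyperparameter is added (the name is hypothetical):

    import jax.numpy as jnp
    import optax

    tau = 0.005  # hypothetical soft-update coefficient; tau = 1.0 recovers the current hard update
    params = {"w": jnp.ones(3)}
    target_params = {"w": jnp.zeros(3)}

    # Polyak update: tau * params + (1 - tau) * target_params
    target_params = optax.incremental_update(params, target_params, tau)

i.e., the hard-coded `1` in the line above would just become the new coefficient.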

Feel free to make a PR.