seungeunrho / minimalRL

Implementations of basic RL algorithms with minimal lines of code! (PyTorch-based)

Improper asynchronous update in a3c

rahulptel opened this issue

I doubt that the asynchronous update in the current a3c adheres to what the paper suggests. Suppose the workers share shared_model. Then each worker should:

  1. Copy the weights of the shared network into its local_model
  2. Run for n steps or until the end of the episode
  3. Calculate the gradient
  4. Pass the gradient of local_model to shared_model
  5. Update shared_model and go back to step 1
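The five steps can be sketched as a toy, single-process loop. This is only an illustration under stated assumptions: the "model" is a plain list of floats, the loss is a hypothetical sum(w ** 2), `act` is a hypothetical stand-in for an environment step, and the optimizer is plain SGD. The point it shows is that the weights used during a rollout come from a local copy, and the copy is refreshed at the start of every iteration.

```python
def toy_grad(weights):
    """Gradient of the hypothetical toy loss sum(w ** 2) with respect to each weight."""
    return [2.0 * w for w in weights]


def act(weights, t):
    """Hypothetical stand-in for one environment step: reports the weights used."""
    return tuple(weights)


def worker(shared, n_iterations, n_steps, lr):
    """Repeat steps 1-5 from the issue; return the weights used in each rollout."""
    history = []
    for _ in range(n_iterations):
        # 1. Copy the shared weights into a local model.
        local = list(shared)
        # 2. Take n steps with the local weights only; nothing writes to `local`.
        used = [act(local, t) for t in range(n_steps)]
        # 3. Compute the gradient on the same local weights the rollout used.
        grad = toy_grad(local)
        # 4-5. Hand the gradient to the shared model, apply plain SGD,
        #      then loop back to step 1 (re-sync on the next iteration).
        for i, g in enumerate(grad):
            shared[i] -= lr * g
        history.append(used)
    return history
```

For example, `worker([1.0], n_iterations=3, n_steps=4, lr=0.25)` uses the weight 1.0 for every step of the first rollout, 0.5 for the second and 0.25 for the third, leaving the shared weight at 0.125.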

Thus, while local_model is taking the n steps, its weights do not change.

However, the current implementation uses shared_model directly for those n steps. Hence a process P1 may update the weights of shared_model while another process P2 is still in the middle of its rollout: P2 started with one weight configuration of shared_model, and that configuration is modified before its n steps are completed.
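The race can be illustrated with a deterministic toy simulation. This is a simplification, not how PyTorch schedules processes: real Hogwild-style updates are lock-free and can land at any point, whereas here a hypothetical P1 (`make_p1_update`) writes to the shared weights at one chosen step of P2's rollout. `use_local_copy=False` mimics acting directly on shared_model; `use_local_copy=True` mimics acting on a local_model snapshot.

```python
def rollout_weights(shared, n_steps, use_local_copy, interrupt):
    """Record the weights P2 acts with at each step of an n-step rollout.

    `interrupt(t)` simulates another process (P1) writing to `shared`
    between P2's steps.
    """
    # Either take a snapshot (local_model) or alias the shared list directly.
    weights = list(shared) if use_local_copy else shared
    seen = []
    for t in range(n_steps):
        seen.append(tuple(weights))  # weights P2 uses at step t
        interrupt(t)                 # P1 may update the shared weights here
    return seen


def make_p1_update(shared, at_step, delta):
    """Hypothetical P1: adds `delta` to every shared weight after step `at_step`."""
    def interrupt(t):
        if t == at_step:
            for i in range(len(shared)):
                shared[i] += delta
    return interrupt
```

With `shared = [0.5, 0.5]` and P1 adding 1.0 after step 1, the direct-access rollout sees `(0.5, 0.5)` for two steps and `(1.5, 1.5)` for the last two, while the snapshot rollout sees `(0.5, 0.5)` throughout; in both cases the shared weights end at `[1.5, 1.5]`, so P1's update is not lost, only deferred until P2's next sync.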

I think the following change to the train method would avoid this:

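import torch.optim as optim

def train(model):
    # `model` is the shared ActorCritic; ActorCritic, learning_rate,
    # max_train_ep, loss_fn, data and labels are assumed to be defined elsewhere
    local_model = ActorCritic()
    # The optimizer updates the shared model's parameters
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Create environment

    for n_epi in range(max_train_ep):
        # 1. Copy the weights of the shared network into local_model
        local_model.load_state_dict(model.state_dict())

        # 2. Take n steps (or run to the end of the episode) using local_model only

        # 3. Calculate the loss on local_model and get its gradients
        local_model.zero_grad()
        loss_fn(local_model(data), labels).backward()  # placeholder for the actor-critic loss

        # 4. Pass the gradients of local_model to the shared model
        for param, shared_param in zip(local_model.parameters(), model.parameters()):
            shared_param._grad = param.grad

        # 5. Update the shared model, then go back to step 1
        optimizer.step()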

I am not very familiar with asynchronous model updates in PyTorch, but looking at the docs at https://pytorch.org/docs/stable/notes/multiprocessing.html#asynchronous-multiprocess-training-e-g-hogwild, I think the current code uses shared_model for everything, including the n-step rollout.

If you think I am correct, I can make a PR.

I think you have made another valid point.
As you noted, in the current implementation the weights a worker acts with can change during its rollout.
They should remain unchanged until the local model has interacted with the environment for n_steps.
If you could revise the issue and make a PR, I would be very thankful.
Thank you again.