koulanurag / muzero-pytorch

PyTorch implementation of MuZero

Problems with Ray version

DominikRoB opened this issue.

I noticed that requirements.txt pins ray 1.13.0, but the code is not compatible with that version. The problem shows up during training, here:

```python
ray.wait(workers, len(workers))
```

`ray.wait()` accepts only one positional argument, and even if I call `ray.wait(workers, num_returns=len(workers))` instead, I get the following error:

```
wait() expected a list of ray.ObjectRef, got list containing <class 'ray.actor.ActorHandle'>
```

Still searching for a solution, but if somebody has a good idea, that would be awesome.
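The two failures have separate causes. The minimal pure-Python sketch below reproduces both with hypothetical stand-ins (`fake_wait`, `ObjectRef`, `ActorHandle`; this is not Ray's code). It assumes Ray's signature is `ray.wait(object_refs, *, num_returns=1, timeout=None, ...)`: the bare `*` makes `num_returns` keyword-only, and the function checks that it received references to task results rather than handles to workers.

```python
class ObjectRef:
    """Hypothetical stand-in for ray.ObjectRef: a handle to a task's result."""

    def __init__(self, value):
        self.value = value


class ActorHandle:
    """Hypothetical stand-in for ray.actor.ActorHandle: a handle to a worker."""


def fake_wait(object_refs, *, num_returns=1, timeout=None):
    """Mock of ray.wait's calling convention only (no real scheduling).

    The bare `*` makes num_returns and timeout keyword-only, so passing
    num_returns positionally raises TypeError. timeout is ignored here.
    """
    for ref in object_refs:
        if not isinstance(ref, ObjectRef):
            # Same wording as the error quoted above.
            raise TypeError(
                "wait() expected a list of ray.ObjectRef, "
                f"got list containing {type(ref)}"
            )
    # Pretend every result is available: the first num_returns are "ready".
    return object_refs[:num_returns], object_refs[num_returns:]


def capture_error(fn, *args, **kwargs):
    """Return the TypeError message raised by fn, or None if it succeeds."""
    try:
        fn(*args, **kwargs)
    except TypeError as err:
        return str(err)
    return None


workers = [ActorHandle() for _ in range(3)]
refs = [ObjectRef(rank) for rank in range(3)]

# 1) num_returns passed positionally, as in ray.wait(workers, len(workers)).
positional_error = capture_error(fake_wait, refs, len(refs))
# 2) Keyword is correct, but the list holds worker handles, not result refs.
type_error = capture_error(fake_wait, workers, num_returns=len(workers))
# 3) Result refs plus keyword num_returns: waits for all of them.
ready, not_ready = fake_wait(refs, num_returns=len(refs))

print(positional_error)
print(type_error)
print(len(ready), len(not_ready))
```

The first call fails with `fake_wait() takes 1 positional argument but 2 were given`, the second with the type error quoted above, and only the third succeeds.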

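Found a solution: `ray.wait()` needs the `ObjectRef`s returned by calling `.remote()` on the actor methods (and on `_test`), not the actor handles themselves, and `num_returns` has to be passed as a keyword argument: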

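```python
import ray


def train(config, summary_writer=None):
    [...]
    workers = [DataWorker.remote(rank, config, storage, replay_buffer)
               for rank in range(0, config.num_actors)]

    # ray.wait() needs ObjectRefs, not ActorHandles: calling .remote() on an
    # actor method or a remote function returns an ObjectRef for its result.
    waiting_fors = [worker.run.remote() for worker in workers]
    waiting_fors.append(_test.remote(config, storage))

    _train(config, storage, replay_buffer, summary_writer)

    # num_returns is keyword-only; block until every task has finished.
    ray.wait(waiting_fors, num_returns=len(waiting_fors))

    # Load the final weights into a fresh network and return the network itself
    # (set_weights() is not relied on to return the network).
    network = config.get_uniform_network()
    network.set_weights(ray.get(storage.get_weights.remote()))
    return network
```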
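For readers without a Ray setup at hand, the same pattern can be sketched with the standard library's `concurrent.futures`. This is a swapped-in analogue, not Ray: `executor.submit()` plays the role of `.remote()` and returns a `Future` (the counterpart of an `ObjectRef`), and `concurrent.futures.wait()` is called on those futures, never on the worker objects. `DataWorker`, `run` and `test_task` below are hypothetical stand-ins for the classes and functions in the snippet above.

```python
from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait


class DataWorker:
    """Hypothetical stand-in for the Ray actor in the snippet above."""

    def __init__(self, rank):
        self.rank = rank

    def run(self):
        # Placeholder for self-play: just report which worker ran.
        return f"worker-{self.rank} done"


def test_task():
    """Hypothetical stand-in for the remote _test function."""
    return "test done"


def train(num_actors=3):
    workers = [DataWorker(rank) for rank in range(num_actors)]
    with ThreadPoolExecutor(max_workers=num_actors + 1) as pool:
        # Collect handles to the *results* (Futures), analogous to the
        # ObjectRefs returned by worker.run.remote() and _test.remote(...).
        waiting_fors = [pool.submit(worker.run) for worker in workers]
        waiting_fors.append(pool.submit(test_task))
        # Block until every task has finished, like
        # ray.wait(waiting_fors, num_returns=len(waiting_fors)).
        done, not_done = wait(waiting_fors, return_when=ALL_COMPLETED)
    # .result() re-raises any exception a task raised.
    return [future.result() for future in waiting_fors], len(not_done)


results, pending = train()
print(results)
print(pending)
```

One design note: `ray.wait` only reports which tasks are ready; an exception raised inside a worker surfaces only when its result is fetched. If that matters, calling `ray.get(waiting_fors)` instead also blocks until all tasks finish and raises if any of them failed.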