fatchord / WaveRNN

WaveRNN Vocoder + TTS

Home Page: https://fatchord.github.io/model_outputs/

How to use multi-GPU?

keto33 opened this issue

I have two GPUs (a 1080 Ti and a 1060). During training, I got the following error:

Traceback (most recent call last):
  File "train_tacotron.py", line 202, in <module>
    main()
  File "train_tacotron.py", line 98, in main
    tts_train_loop(paths, model, optimizer, train_set, lr, training_steps, attn_example)
  File "train_tacotron.py", line 132, in tts_train_loop
    m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
  File "/home/keto/WaveRNN/utils/__init__.py", line 32, in data_parallel_workaround
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
  File "/home/keto/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/home/keto/.local/lib/python3.8/site-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/keto/.local/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/keto/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/keto/WaveRNN/models/tacotron.py", line 311, in forward
    device = next(self.parameters()).device  # use same device as parameters
StopIteration

I believe the problem is related to how work is distributed across the devices, because I could resolve it by modifying this block in train_tacotron.py:

    # Parallelize model onto GPUS using workaround due to python bug
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
    else:
        m1_hat, m2_hat, attention = model(x, m)

I simply used the else branch even though the device count is 2 (see the sketch below).

My second GPU is much weaker and cannot contribute much, but I thought it might be useful to report the issue.
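For context, a minimal sketch of that change, assuming the surrounding variables device, model, x and m from the training loop in train_tacotron.py; the use_data_parallel flag is a hypothetical addition to force the single-GPU branch, not the upstream fix:

    # Hypothetical flag: force the single-GPU branch even when more than one
    # CUDA device is visible, to sidestep the StopIteration in the replicas.
    use_data_parallel = False

    if device.type == 'cuda' and torch.cuda.device_count() > 1 and use_data_parallel:
        m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
    else:
        m1_hat, m2_hat, attention = model(x, m)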

I think it's a PyTorch bug:

huggingface/transformers#3936
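For anyone who cannot downgrade: the linked issue suggests the StopIteration comes from calling next(self.parameters()) inside forward(), because DataParallel replicas in PyTorch >= 1.5 expose an empty parameter iterator. A minimal sketch of the usual workaround, reading the device from the input tensor instead; DeviceAwareModule is a hypothetical toy module, not the repo's Tacotron class:

    import torch
    import torch.nn as nn

    class DeviceAwareModule(nn.Module):
        """Toy module: derive the device from the input tensor instead of
        next(self.parameters()), which raises StopIteration inside
        DataParallel replicas on PyTorch >= 1.5."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 8)

        def forward(self, x):
            device = x.device                      # no parameter lookup needed
            bias = torch.zeros(1, device=device)   # e.g. some internal tensor
            return self.linear(x) + bias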

Maybe you should try downgrading your torch package?

Using PyTorch 1.4.0 may solve this problem.

As mentioned in CorentinJ/Real-Time-Voice-Cloning#664, using torch 1.4 doesn't work. The error I got is:
"AttributeError: 'PosixPath' object has no attribute 'tell'"
I googled it and found that to solve it I have to use a torch version above 1.6.
Awkward face...
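For what it's worth, that particular AttributeError usually means the installed torch build is treating the pathlib.Path as a file-like object (and calling .tell() on it) rather than opening it as a filename; passing str(path) works around that specific error. It won't help if the checkpoint was saved with torch >= 1.6, whose zip-based format older releases can't read. A minimal sketch, with a hypothetical checkpoint path:

    import torch
    from pathlib import Path

    checkpoint_path = Path('checkpoints/tacotron_weights.pyt')  # hypothetical path

    # Some torch builds only accept a filename string or an open file here;
    # passing str(path) avoids the AttributeError on PosixPath.tell().
    state_dict = torch.load(str(checkpoint_path), map_location='cpu')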
