allenai / allennlp

An open-source NLP research library, built on PyTorch.

Home Page: http://www.allennlp.org

Training a transformer classifier with DDP

david-waterworth opened this issue · comments

I've pre-trained my own Hugging Face RoBERTa transformer and I'm fine-tuning a classifier with it. I'm using the following components:

    embedder: {
        token_embedders: {
            tokens: {
                type: "pretrained_transformer",
                model_name: MODEL_NAME,
                transformer_kwargs: {
                    add_pooling_layer: false
                }
            }
        },
    },
    encoder: {
        type: "pass_through",
        input_dim: MODEL_DIM
    },
    pooler: {
        type: "bert_pooler",
        pretrained_model: MODEL_NAME
    },

This works fine when I train on a single GPU; however, it fails when I try to use DDP - I get errors about unused parameters not contributing to the loss.

Looking through the code, it seems to be related to the pooler. When you load a RobertaForMaskedLM using AutoModel.from_pretrained, it drops the language-model head and adds a pooler layer.

Looking at the PretrainedTransformerEmbedder code, it loads the model from the cache (copy=True), including the pooler - however, in forward the embeddings are extracted before the pooler is applied.

BertPooler, however, loads the model from the cache (copy=False) and then deep-copies the pooler.

This means there are two poolers, and if I'm reading things correctly the first one never contributes to the gradients.
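
To illustrate, here's a minimal plain-transformers sketch (not AllenNLP's actual code, and using "roberta-base" as a stand-in for MODEL_NAME) of why the embedder's own pooler weights never receive gradients:

import copy
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("roberta-base")  # drops the LM head, adds a pooler
pooler_copy = copy.deepcopy(model.pooler)          # roughly what bert_pooler does

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
inputs = tokenizer("Hello world", return_tensors="pt")

embeddings = model(**inputs).last_hidden_state  # what the embedder passes downstream
pooled = pooler_copy(embeddings)                # the deep copy feeds the classifier ...
pooled.sum().backward()

# ... so the embedder's own model.pooler never receives a gradient, which is
# exactly the kind of "unused parameter" DDP complains about.
print(model.pooler.dense.weight.grad)             # None
print(pooler_copy.dense.weight.grad is not None)  # True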

Also, if I pass transformer_kwargs: { add_pooling_layer: false } to the embedder, there's no pooler at all and the bert_pooler throws an exception.
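
A quick check of the add_pooling_layer behaviour, again in plain transformers with "roberta-base" standing in for the real model:

from transformers import AutoModel

with_pooler = AutoModel.from_pretrained("roberta-base")
without_pooler = AutoModel.from_pretrained("roberta-base", add_pooling_layer=False)

print(type(with_pooler.pooler).__name__)  # RobertaPooler
print(without_pooler.pooler)              # None - nothing left for bert_pooler to copy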

Is this not the intended way of using the pooler?

As an aside, torch oddly recommends setting find_unused_parameters=True in the error message, which I assumed was just a diagnostic, but it actually seems to fix the problem?
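
For anyone hitting the same thing: the flag is more than a diagnostic. Here's a standalone sketch (not AllenNLP code; a toy module of my own, launched with torchrun) of a parameter that never contributes to the loss, which is the situation DDP is objecting to:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class Toy(torch.nn.Module):
    """A module with a submodule that never contributes to the loss,
    analogous to the embedder's unused pooler."""

    def __init__(self):
        super().__init__()
        self.used = torch.nn.Linear(8, 8)
        self.unused = torch.nn.Linear(8, 8)  # never called in forward

    def forward(self, x):
        return self.used(x)

def main():
    dist.init_process_group("nccl" if torch.cuda.is_available() else "gloo")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device("cuda", local_rank) if torch.cuda.is_available() else torch.device("cpu")
    if device.type == "cuda":
        torch.cuda.set_device(device)

    # With find_unused_parameters=False, DDP errors out (or hangs) because
    # Toy.unused never receives a gradient; with True, it re-checks the autograd
    # graph after every forward pass and marks unused parameters as ready.
    ddp_model = DDP(Toy().to(device), find_unused_parameters=True)
    ddp_model(torch.randn(4, 8, device=device)).sum().backward()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()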

I ran into a similar error. I changed the following find_unused_parameters default to True and the error went away.

def __init__(
    self,
    *,
    find_unused_parameters: bool = False,
    local_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    cuda_device: Union[torch.device, int] = -1,
) -> None:
    super().__init__(local_rank=local_rank, world_size=world_size, cuda_device=cuda_device)
    self._ddp_kwargs = {
        "find_unused_parameters": find_unused_parameters,
    }

Though I don't know how to configure it.

@huhk-sysu I managed to configure find_unused_parameters as follows

distributed: {
    cuda_devices: if NUM_GPUS > 1 then std.range(0, NUM_GPUS - 1) else 0,
    ddp_accelerator: {
        type: "torch",
        find_unused_parameters: true
    }
},

By explicitly adding ddp_accelerator you can set the parameters - otherwise, as you show in the code above, it creates a default with find_unused_parameters=False.
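
Roughly speaking, that "torch" ddp_accelerator just stores the kwargs from the config and forwards them to DistributedDataParallel when the model gets wrapped. A simplified sketch of the idea (class and method names are mine, not AllenNLP's):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

class DdpAcceleratorSketch:
    """Illustrative only: keep the configured kwargs and pass them through
    to DistributedDataParallel when wrapping the model."""

    def __init__(self, find_unused_parameters: bool = False, cuda_device: int = -1) -> None:
        self._cuda_device = cuda_device
        self._ddp_kwargs = {"find_unused_parameters": find_unused_parameters}

    def wrap_model(self, model: torch.nn.Module) -> torch.nn.Module:
        device_ids = None if self._cuda_device < 0 else [self._cuda_device]
        return DDP(model, device_ids=device_ids, **self._ddp_kwargs)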

Thanks, it works for me.

if NUM_GPUS > 1 then std.range(0, NUM_GPUS - 1) else 0

@david-waterworth

It may fail when NUM_GPUS == 1. Maybe the following is better:

distributed: if NUM_GPUS > 1 then {
    cuda_devices: std.range(0, NUM_GPUS - 1),
    ddp_accelerator: {
        type: "torch",
        find_unused_parameters: true
    }
}

This issue is being closed due to lack of activity. If you think it still needs to be addressed, please comment on this thread 👇