CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

strange design

efengx opened this issue · comments

commented

πŸ› Describe the bug

Run:
accelerate launch examples/ppo_sentiments_llama.py

Error:
ValueError: Unsupported architecture: LLaMAForCausalLM. The following architectures are available for model branching:
['GPTJForCausalLM', 'GPT2LMHeadModel', 'GPTNeoForCausalLM',
'GPTNeoXForCausalLM', 'OPTForCausalLM', 'BloomModel', 'BloomForCausalLM',
'LlamaModel', 'LlamaForCausalLM']

Traceback:
trlx/models/modeling_ppo.py > hf_get_branch_class (line 1290)

Code:
https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py
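
For context, hf_get_branch_class dispatches on the architecture string recorded in the checkpoint's config.json, so the mismatch can be reproduced by inspecting the config directly. A minimal sketch, assuming a placeholder checkpoint path and a checkpoint converted with an older transformers release that still used the pre-rename casing:

from transformers import AutoConfig

# Placeholder path: substitute the LLaMA checkpoint used by
# examples/ppo_sentiments_llama.py.
config = AutoConfig.from_pretrained("path/to/llama-checkpoint")

# This is the string hf_get_branch_class compares against its hard-coded
# lists. Older conversions report the pre-rename casing "LLaMAForCausalLM",
# which is exactly the string the ValueError above complains about.
print(config.architectures)  # e.g. ['LLaMAForCausalLM']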

Question:
Model architectures in Hugging Face transformers change frequently, so why does the model identification here rely on a static, hard-coded list?

gpt_branch_supported_archs = [
    "GPTJForCausalLM",
    "GPT2LMHeadModel",
    "GPTNeoForCausalLM",
    "GPTNeoXForCausalLM",
]
opt_branch_supported_archs = ["OPTForCausalLM"]
bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"]
llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM"]
# Dispatch on the first architecture string from the checkpoint's config
arch = config.architectures[0]
if arch in gpt_branch_supported_archs:
    return GPTModelBranch
elif arch in opt_branch_supported_archs:
    return OPTModelBranch
elif arch in bloom_branch_supported_archs:
    return BloomModelBranch
elif arch in llama_branch_supported_archs:
    return LlamaModelBranch
else:
    # Flatten the per-family lists for the error message
    all_supported_archs = sum(
        [
            gpt_branch_supported_archs,
            opt_branch_supported_archs,
            bloom_branch_supported_archs,
            llama_branch_supported_archs,
        ],
        [],
    )
    raise ValueError(
        f"Unsupported architecture: `{arch}`. The following architectures are "
        f"available for model branching:\n{all_supported_archs}"
    )

Maybe it would be more reliable to dispatch on the config's parameters and model_type instead?
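
For illustration, a minimal sketch of a model_type-based dispatch, assuming the branch classes from modeling_ppo.py are importable as below; config.model_type is the lowercase family name ("gpt2", "opt", "bloom", "llama", ...) and stays stable when transformers renames a class:

from trlx.models.modeling_ppo import (
    BloomModelBranch,
    GPTModelBranch,
    LlamaModelBranch,
    OPTModelBranch,
)

# Hypothetical mapping from transformers model_type to trlx branch class
MODEL_TYPE_TO_BRANCH = {
    "gptj": GPTModelBranch,
    "gpt2": GPTModelBranch,
    "gpt_neo": GPTModelBranch,
    "gpt_neox": GPTModelBranch,
    "opt": OPTModelBranch,
    "bloom": BloomModelBranch,
    "llama": LlamaModelBranch,
}

def hf_get_branch_class_by_type(config):
    # model_type comes from the checkpoint's config.json and is unaffected
    # by class renames such as LLaMAForCausalLM -> LlamaForCausalLM
    try:
        return MODEL_TYPE_TO_BRANCH[config.model_type]
    except KeyError:
        raise ValueError(
            f"Unsupported model_type: `{config.model_type}`. The following "
            f"types are available for model branching:\n{sorted(MODEL_TYPE_TO_BRANCH)}"
        )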

Workaround (skipping the design question for now): change

llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM"]

to

llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM", "LLaMAForCausalLM"]
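
With that one-line change the dispatch should resolve again; a quick sanity check, assuming hf_get_branch_class takes the model config as the snippet above suggests, and reusing the same placeholder checkpoint path:

from transformers import AutoConfig
from trlx.models.modeling_ppo import hf_get_branch_class

# Placeholder path: a checkpoint whose config.json still reports the old
# "LLaMAForCausalLM" casing.
config = AutoConfig.from_pretrained("path/to/llama-checkpoint")
branch_class = hf_get_branch_class(config)
print(branch_class.__name__)  # expected: LlamaModelBranch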

Which trlX version are you using?

main

Additional system and package information

trlx/requirements.txt

commented

This is a good point; however, the difficulty still revolves around the hydra heads for the reference policy and the value network, and the need to have a branch class for them for each architecture. The issue at hand was resolved by transformers, see #476