strange design
efengx opened this issue
🐛 Describe the bug
Run:
accelerate launch examples/ppo_sentiments_llama.py
This fails with:
ValueError: Unsupported architecture: `LLaMAForCausalLM`. The following architectures are available for model branching:
['GPTJForCausalLM', 'GPT2LMHeadModel', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'OPTForCausalLM', 'BloomModel', 'BloomForCausalLM', 'LlamaModel', 'LlamaForCausalLM']
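The checkpoint's config.json presumably still carries the older, pre-rename class name written by earlier LLaMA conversion scripts. A quick way to inspect what trlx will actually receive (the checkpoint path is a placeholder):

from transformers import AutoConfig

# Placeholder path; point this at the converted LLaMA checkpoint.
config = AutoConfig.from_pretrained("path/to/llama-checkpoint")
print(config.architectures)  # e.g. ['LLaMAForCausalLM'] for configs written
                             # by older conversion scripts
print(config.model_type)     # 'llama', regardless of the class-name casing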
Traceback location:
trlx/models/modeling_ppo.py > hf_get_branch_class, line 1290
Code (the body of hf_get_branch_class):
https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py
Question:
Architecture class names in Hugging Face change frequently, so why does model identification here rely on a static, hard-coded list of those names?
gpt_branch_supported_archs = [
    "GPTJForCausalLM",
    "GPT2LMHeadModel",
    "GPTNeoForCausalLM",
    "GPTNeoXForCausalLM",
]
opt_branch_supported_archs = ["OPTForCausalLM"]
bloom_branch_supported_archs = ["BloomModel", "BloomForCausalLM"]
llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM"]
arch = config.architectures[0]
if arch in gpt_branch_supported_archs:
    return GPTModelBranch
elif arch in opt_branch_supported_archs:
    return OPTModelBranch
elif arch in bloom_branch_supported_archs:
    return BloomModelBranch
elif arch in llama_branch_supported_archs:
    return LlamaModelBranch
else:
    all_supported_archs = sum(
        [
            gpt_branch_supported_archs,
            opt_branch_supported_archs,
            bloom_branch_supported_archs,
            llama_branch_supported_archs,
        ],
        [],
    )
    raise ValueError(
        f"Unsupported architecture: `{arch}`. The following architectures are "
        f"available for model branching:\n{all_supported_archs}"
    )
Maybe it would be more reliable to dispatch on config.model_type (together with the model's parameters)? A sketch of that idea follows.
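A minimal sketch of model_type-based dispatch (my suggestion, not trlx's current code; the mapping and function name are hypothetical, and the branch classes are assumed importable from trlx.models.modeling_ppo, where the code above defines them). config.model_type is a stable short identifier and is unaffected by class-name renames such as LLaMAForCausalLM -> LlamaForCausalLM:

from trlx.models.modeling_ppo import (
    BloomModelBranch,
    GPTModelBranch,
    LlamaModelBranch,
    OPTModelBranch,
)

# Hypothetical mapping from config.model_type to the branch class.
MODEL_TYPE_TO_BRANCH = {
    "gptj": GPTModelBranch,
    "gpt2": GPTModelBranch,
    "gpt_neo": GPTModelBranch,
    "gpt_neox": GPTModelBranch,
    "opt": OPTModelBranch,
    "bloom": BloomModelBranch,
    "llama": LlamaModelBranch,
}

def hf_get_branch_class_by_model_type(config):
    try:
        return MODEL_TYPE_TO_BRANCH[config.model_type]
    except KeyError:
        raise ValueError(
            f"Unsupported model_type: `{config.model_type}`. Supported model "
            f"types: {sorted(MODEL_TYPE_TO_BRANCH)}"
        )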
Setting that question aside for now, I worked around the error by changing:
llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM"]
to
llama_branch_supported_archs = ["LlamaModel", "LlamaForCausalLM", "LLaMAForCausalLM"]
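An alternative workaround that avoids enumerating every casing variant would be to fold the architecture name before the lookup; a self-contained sketch (not trlx's code):

# Sketch: case-insensitive membership test, so both "LlamaForCausalLM" and
# the older "LLaMAForCausalLM" spelling resolve to the Llama branch.
LLAMA_ARCHS = {"llamamodel", "llamaforcausallm"}

def is_llama_arch(arch: str) -> bool:
    return arch.lower() in LLAMA_ARCHS

assert is_llama_arch("LlamaForCausalLM")
assert is_llama_arch("LLaMAForCausalLM")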
Which trlX version are you using?
main
Additional system and package information
trlx/requirements.txt