magic-research / PLLaVA

Official repository for the paper PLLaVA


How to eval using HF?

QiaoZhennn opened this issue

# Load model directly
from transformers import AutoProcessor, AutoModelForSeq2SeqLM

processor = AutoProcessor.from_pretrained("ermu2001/pllava-34b")
model = AutoModelForSeq2SeqLM.from_pretrained("ermu2001/pllava-34b")

I tried to load the model using this demo code, but it shows the following error. Is there an example of how to run inference using Hugging Face?
Unrecognized configuration class <class 'transformers.models.llava.configuration_llava.LlavaConfig'> for this kind of AutoModel: AutoModelForSeq2SeqLM. Model type should be one of BartConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, EncoderDecoderConfig, FSMTConfig, GPTSanJapaneseConfig, LEDConfig, LongT5Config, M2M100Config, MarianConfig, MBartConfig, MT5Config, MvpConfig, NllbMoeConfig, PegasusConfig, PegasusXConfig, PLBartConfig, ProphetNetConfig, SeamlessM4TConfig, SeamlessM4Tv2Config, SwitchTransformersConfig, T5Config, UMT5Config, XLMProphetNetConfig.

My transformers version is 4.39.2.

Hi,

The model weights we've uploaded are saved in transformers PEFT LoRA format, so they can't be loaded directly with the transformers auto-loading code yet. To load our model, please check out this function in our code for reference. Using this function, you should be able to load the model together with its PEFT LoRA weights.

import os

import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from peft import LoraConfig, TaskType, get_peft_model
from safetensors import safe_open

# These classes live in this repository (models/pllava).
from models.pllava import PllavaConfig, PllavaForConditionalGeneration, PllavaProcessor

def load_pllava(repo_id, num_frames, use_lora=False, weight_dir=None, lora_alpha=32, use_multi_gpus=False, pooling_shape=(16, 12, 12)):
    kwargs = {
        'num_frames': num_frames,
    }
    # print("===============>pooling_shape", pooling_shape)
    if num_frames == 0:
        kwargs.update(pooling_shape=(0, 12, 12))  # would produce a bug if the pooling projector were ever used
    config = PllavaConfig.from_pretrained(
        repo_id if not use_lora else weight_dir,
        pooling_shape=pooling_shape,
        **kwargs,
    )
    with torch.no_grad():
        model = PllavaForConditionalGeneration.from_pretrained(repo_id, config=config, torch_dtype=torch.bfloat16)
    try:
        processor = PllavaProcessor.from_pretrained(repo_id)
    except Exception:
        # fall back to the base llava processor if the repo has no processor files
        processor = PllavaProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf')

    # config lora: wrap the language model with a PEFT LoRA adapter
    if use_lora and weight_dir is not None:
        print("Use lora")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False, target_modules=["q_proj", "v_proj"],
            r=128, lora_alpha=lora_alpha, lora_dropout=0.,
        )
        print("Lora Scaling:", lora_alpha / 128)
        model.language_model = get_peft_model(model.language_model, peft_config)
        assert weight_dir is not None, "pass a folder to your lora weight"
        print("Finish use lora")

    # load weights: a sharded checkpoint (model-0...) takes precedence over a single model.safetensors
    if weight_dir is not None:
        state_dict = {}
        save_fnames = os.listdir(weight_dir)
        if "model.safetensors" in save_fnames:
            use_full = False
            for fn in save_fnames:
                if fn.startswith('model-0'):
                    use_full = True
                    break
        else:
            use_full = True

        if not use_full:
            print("Loading weight from", weight_dir, "model.safetensors")
            with safe_open(f"{weight_dir}/model.safetensors", framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
        else:
            print("Loading weight from", weight_dir)
            for fn in save_fnames:
                if fn.startswith('model-0'):
                    with safe_open(f"{weight_dir}/{fn}", framework="pt", device="cpu") as f:
                        for k in f.keys():
                            state_dict[k] = f.get_tensor(k)

        if 'model' in state_dict.keys():
            msg = model.load_state_dict(state_dict['model'], strict=False)
        else:
            msg = model.load_state_dict(state_dict, strict=False)
        print(msg)

    # dispatch model weight across GPUs, never splitting a decoder layer
    if use_multi_gpus:
        max_memory = get_balanced_memory(
            model,
            max_memory=None,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16',
            low_zero=False,
        )
        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16',
        )
        dispatch_model(model, device_map=device_map)
        print(model.hf_device_map)

    model = model.eval()
    return model, processor
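
For reference, here is a minimal inference sketch built on top of load_pllava. The import path (tasks.eval.model_utils), the frame file names, and the prompt template are assumptions for illustration; check the demo code in this repository for the exact conversation format.

# A minimal usage sketch, not verbatim from the repo: the import path,
# frame file names, and prompt template are illustrative assumptions.
import torch
from PIL import Image
from tasks.eval.model_utils import load_pllava  # assumed location of the function above

model, processor = load_pllava("ermu2001/pllava-7b", num_frames=16)
model = model.to("cuda")

# 16 frames sampled from a video, as PIL images (placeholder file names)
frames = [Image.open(f"frame_{i:02d}.jpg") for i in range(16)]
prompt = "USER: <image>\nWhat is happening in this video? ASSISTANT:"

inputs = processor(text=prompt, images=frames, return_tensors="pt").to("cuda")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output_ids[0], skip_special_tokens=True))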

By the way, if you wish to run the demo, you could execute this script.

If you want to evaluate our model directly, you can start by following the instructions here to prepare the data, then execute this script.