magic-research / PLLaVA

Official repository for the paper PLLaVA

Fine-tuning on my own task based on pllava-7b fails with the "image_to_overwrite.sum() != image_features.shape[:-1].numel()" assertion.

gaowei724 opened this issue

Thank you for your work. I'd like to use your code to train on my own dataset and test its precision. Following the guide in the repo, I used llava-v1.6-vicuna-7b-hf as the pre-trained model and ran python tasks/train/train_pllava_nframe_accel.py tasks/train/config_pllava_nframe.py model.repo_id MODELS/llava-v1.6-vicuna-7b-hf. Training runs normally and outputs a loss.

However, when I try to continue training from your already fine-tuned model pllava-7b by running python tasks/train/train_pllava_nframe_accel.py tasks/train/config_pllava_nframe.py model.repo_id MODELS/pllava-7b, an assertion fails in the _merge_input_ids_with_image_features function: "image_to_overwrite.sum() != image_features.shape[:-1].numel()".

I suspect that the generated image_to_overwrite is incorrect, but why does this happen? How should I load pllava-7b or pllava-13b to continue training?
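For readers debugging the same assertion, here is a minimal sketch of the bookkeeping it checks. It is a simplified model, not the library's code: the positions reserved by image placeholder tokens in the prompt have to equal the number of image feature vectors. The token id 32000, the patch count 576, and all function names are assumptions for illustration only; read the real values from your model config and batch.

```python
# Simplified model of the check behind the assertion (not the actual library code).
# Assumption: each image placeholder token expands to `patches_per_image` positions.
IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id; take the real one from the model config


def count_image_slots(input_ids, image_token_id, patches_per_image):
    """Positions the image placeholders would reserve in the merged sequence."""
    num_image_tokens = sum(1 for tok in input_ids if tok == image_token_id)
    return num_image_tokens * patches_per_image


def feature_count(num_images, patches_per_image):
    """Total number of image feature vectors produced by the vision side."""
    return num_images * patches_per_image


def slots_match(input_ids, num_images, patches_per_image,
                image_token_id=IMAGE_TOKEN_ID):
    """True when reserved positions and available features agree."""
    reserved = count_image_slots(input_ids, image_token_id, patches_per_image)
    return reserved == feature_count(num_images, patches_per_image)


if __name__ == "__main__":
    prompt = [1, IMAGE_TOKEN_ID, 319, 13563]   # one placeholder in the prompt
    print(slots_match(prompt, num_images=1, patches_per_image=576))
    print(slots_match(prompt, num_images=2, patches_per_image=576))
```

If the two counts differ for a batch, it is worth checking that the tokenizer loaded with the checkpoint actually maps the image placeholder to the expected id and that the prompt contains one placeholder per image or frame group the model expects.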

This fixed the problem #26 (comment)