haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.

Home Page: https://llava.hliu.cc

[Usage] Must I reload the model to run inference on a new image?

lin-whale opened this issue

Describe the issue

Loading the model takes a long time, so I am trying to reuse the already-loaded model when running inference on a new image. However, I run into the error below. Is this possible, and how should I write the code?
The code below is modified from llava/serve/cli.py:

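import torch
from transformers import TextStreamer

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
# load_image() is the helper defined in llava/serve/cli.py itself

def main(args):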
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    ...

    while True:
        image_file = input("Please input image path:")
        image = load_image(image_file)
        image_size = image.size
        # Similar operation in model_worker.py
        image_tensor = process_images([image], image_processor, model.config)
        if type(image_tensor) is list:
            image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
        else:
            image_tensor = image_tensor.to(model.device, dtype=torch.float16)

        while True:
            try:
                inp = input(f"{roles[0]}: ")
            except EOFError:
                inp = ""
            if not inp:
                print("exit...")
                break

            print(f"{roles[1]}: ", end="")

            if image is not None:
                # first message
                if model.config.mm_use_im_start_end:
                    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
                else:
                    inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
                image = None
            
            conv.append_message(conv.roles[0], inp)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor,
                    image_sizes=[image_size],
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    max_new_tokens=args.max_new_tokens,
                    streamer=streamer,
                    use_cache=True)

            outputs = tokenizer.decode(output_ids[0]).strip()
            conv.messages[-1][-1] = outputs

            if args.debug:
                print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

The code works on the first image but fails on the second one.

Output:

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:29<00:00,  1.94s/it]
Please input image path:/home/aistar/llava/data/view.jpg
<|im_start|>user
: hello
<|im_start|>assistant
: Hello! This is a beautiful image of a wooden dock extending into a serene lake. The calm water reflects the surrounding landscape, which includes a forest and mountains in the distance. The sky is partly cloudy, suggesting a pleasant day. The dock appears to be a quiet spot for relaxation or perhaps a starting point for boating or fishing. It's a peaceful scene that evokes a sense of tranquility and connection with nature.
<|im_start|>user
: 
exit...
Please input image path:/home/aistar/llava/data/view.jpg
<|im_start|>user
: hello
<|im_start|>assistant
: Traceback (most recent call last):
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/aistar/llava/LLaVA/llava/serve/cli_multi_turn.py", line 137, in <module>
    main(args)
  File "/home/aistar/llava/LLaVA/llava/serve/cli_multi_turn.py", line 107, in main
    output_ids = model.generate(
  File "/home/aistar/llava/annaconda3/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/aistar/llava/LLaVA/llava/model/language_model/llava_llama.py", line 125, in generate
    ) = self.prepare_inputs_labels_for_multimodal(
  File "/home/aistar/llava/LLaVA/llava/model/llava_arch.py", line 260, in prepare_inputs_labels_for_multimodal
    cur_image_features = image_features[cur_image_idx]
IndexError: list index out of range
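A likely cause, assuming `conv` is created once in the elided setup before the outer `while` loop (as in the stock `llava/serve/cli.py`): the conversation history is never cleared, so the second image's first message adds a second `<image>` placeholder while the first one is still in the prompt. `model.generate` then sees two image tokens but receives only one image tensor, and `prepare_inputs_labels_for_multimodal` runs past the end of `image_features` on the second placeholder. The sketch below is a simplified, hypothetical re-creation of that splicing step (the real function also builds embeddings, labels, and padding); it uses `-200`, the value of LLaVA's `IMAGE_TOKEN_INDEX`, as the placeholder id, and `build_ids` is a made-up helper for illustration.

```python
from typing import List

IMAGE_TOKEN_INDEX = -200  # placeholder id inserted for each <image> in the prompt


def splice_image_features(input_ids: List[int], image_features: List[str]) -> List[object]:
    """Replace each image placeholder with the next entry of image_features.

    Simplified sketch of the indexing done in prepare_inputs_labels_for_multimodal.
    """
    spliced: List[object] = []
    cur_image_idx = 0
    for token in input_ids:
        if token == IMAGE_TOKEN_INDEX:
            # Raises IndexError once there are more placeholders than images.
            spliced.append(image_features[cur_image_idx])
            cur_image_idx += 1
        else:
            spliced.append(token)
    return spliced


def build_ids(num_image_turns: int) -> List[int]:
    # Hypothetical: one placeholder per user turn that introduced an image, plus text tokens.
    ids: List[int] = []
    for _ in range(num_image_turns):
        ids += [IMAGE_TOKEN_INDEX, 101, 102]
    return ids


# Second image, conversation NOT reset: two placeholders, one image tensor.
try:
    splice_image_features(build_ids(2), ["features_of_image_2"])
    reused_failed = False
except IndexError:
    reused_failed = True

# Second image, fresh conversation: one placeholder, one image tensor.
fresh = splice_image_features(build_ids(1), ["features_of_image_2"])

print(reused_failed)  # True
print(fresh)          # ['features_of_image_2', 101, 102]
```

Under that assumption, the model does not need to be reloaded. Re-creating `conv` (and `roles`) at the top of the outer loop, with the same `conv_templates[...].copy()` call used in the elided setup, gives each new image a prompt with exactly one placeholder. Keeping the earlier turns instead would require passing every image that still has a placeholder in the prompt, which the code above does not do.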