HaozheZhao / MIC

MMICL, a state-of-the-art VLM with in-context learning ability, from ICL, PKU

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Batch inference support?

OuYangg opened this issue · comments

commented

I measured the inference speed of the MMICL model on a 32 GB V100, and it took about 1.5 seconds per image. Does the MMICL model support batch inference?

Absolutely. It supports both batch training and batch inference. You only need to pad each instance's images to the same count (the largest number of images of any instance in the batch), then pass the corresponding image mask so that the padded images are masked out.

Here is the code for batch inference:

from torch.nn.functional import pad
def padd_images(image, max_length):
    """Zero-pad one instance's image stack along its first dimension and build its mask.

    Args:
        image: Array-like (tensor, list or numpy array) whose first dimension
            indexes the images of a single instance, e.g. the ``pixel_values``
            the processor returns for that instance.
        max_length: Number of image slots every instance in the batch must have.

    Returns:
        ``(padded, mask)``. ``padded`` has exactly ``max_length`` entries along
        the first dimension, and the appended entries are all zeros. ``mask`` is
        a bool tensor of length ``max_length``: ``True`` for real images and
        ``False`` for padding slots.

    Raises:
        ValueError: If ``image`` already holds more than ``max_length`` entries.
            The old code silently truncated in that case but still marked every
            slot as a real image.
    """
    # as_tensor reuses an existing tensor instead of copy-constructing it, which
    # avoids the warning torch.tensor() emits when given a tensor.
    image = torch.as_tensor(image)
    num_images = image.shape[0]
    pad_len = max_length - num_images
    if pad_len < 0:
        raise ValueError(
            f"got {num_images} images but max_length is {max_length}"
        )
    mask = torch.zeros(max_length, dtype=torch.bool)
    mask[:num_images] = True
    # F.pad takes (left, right) pairs starting from the LAST dimension, so only
    # the final pair touches dim 0. Building the spec from image.dim() keeps this
    # correct for any rank; for 4-D input it equals (0, 0, 0, 0, 0, 0, 0, pad_len).
    pad_spec = (0, 0) * (image.dim() - 1) + (0, pad_len)
    image = pad(image, pad_spec)
    return image, mask

# Load the demo images. The batch holds two instances with different image
# counts (3 and 2); Image.open is called in the same order as before.
images = [
    [
        Image.open("images/chinchilla.png"),
        Image.open("images/shiba.png"),
        Image.open("images/flamingo.png"),
    ],
    [
        Image.open("images/shiba.png"),
        Image.open("images/flamingo.png"),
    ],
]

# One interleaved prompt per instance; <imageN> refers to the N-th image of
# that instance.
prompt = [
    f'image 0 is <image0>{replace_token},image 1 is <image1>{replace_token},image 2 is <image2>{replace_token}. Question: <image0> is a chinchilla. They are mainly found in Chile.\n Question: <image1> is a shiba. They are very popular in Japan.\nQuestion: image 2 is',
    f'image 0 is <image0>{replace_token}, image 0 is a shiba. They are very popular in Japan.\n image 1 is <image1>{replace_token}, image 1 is a',
]

# Every instance is padded up to the largest image count in the batch.
max_image_length = max(len(instance_images) for instance_images in images)

# Tokenize all prompts together; padding=True aligns them to one length.
inputs = processor(text=prompt, return_tensors="pt", padding=True)

# Preprocess each instance's images separately, since their counts differ.
pixel_values = [
    processor(images=instance_images, return_tensors="pt")["pixel_values"]
    for instance_images in images
]

# Pad each instance to max_image_length and collect the matching masks.
padded_batch = [
    padd_images(instance_pixels, max_image_length)
    for instance_pixels in pixel_values
]
inputs["pixel_values"] = torch.stack(
    [padded for padded, _ in padded_batch]
).to(torch.bfloat16)
inputs["img_mask"] = torch.stack([mask for _, mask in padded_batch])
inputs = inputs.to("cuda:1")

# Deterministic decoding (do_sample=False). img_mask marks which image slots
# hold real images and which are padding.
generation_options = dict(
    do_sample=False,
    max_length=50,
    min_length=1,
    set_min_padding_size=False,
)
outputs = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    img_mask=inputs["img_mask"],
    **generation_options,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)
print(generated_text)

commented

Absolutely. It supports both batch training and batch inference. You only need to pad each instance's images to the same count (the largest number of images of any instance in the batch), then pass the corresponding image mask so that the padded images are masked out.

Here is the code for batch inference:

from torch.nn.functional import pad
def padd_images(image, max_length):
    """Zero-pad one instance's image stack along its first dimension and build its mask.

    Args:
        image: Array-like (tensor, list or numpy array) whose first dimension
            indexes the images of a single instance, e.g. the ``pixel_values``
            the processor returns for that instance.
        max_length: Number of image slots every instance in the batch must have.

    Returns:
        ``(padded, mask)``. ``padded`` has exactly ``max_length`` entries along
        the first dimension, and the appended entries are all zeros. ``mask`` is
        a bool tensor of length ``max_length``: ``True`` for real images and
        ``False`` for padding slots.

    Raises:
        ValueError: If ``image`` already holds more than ``max_length`` entries.
            The old code silently truncated in that case but still marked every
            slot as a real image.
    """
    # as_tensor reuses an existing tensor instead of copy-constructing it, which
    # avoids the warning torch.tensor() emits when given a tensor.
    image = torch.as_tensor(image)
    num_images = image.shape[0]
    pad_len = max_length - num_images
    if pad_len < 0:
        raise ValueError(
            f"got {num_images} images but max_length is {max_length}"
        )
    mask = torch.zeros(max_length, dtype=torch.bool)
    mask[:num_images] = True
    # F.pad takes (left, right) pairs starting from the LAST dimension, so only
    # the final pair touches dim 0. Building the spec from image.dim() keeps this
    # correct for any rank; for 4-D input it equals (0, 0, 0, 0, 0, 0, 0, pad_len).
    pad_spec = (0, 0) * (image.dim() - 1) + (0, pad_len)
    image = pad(image, pad_spec)
    return image, mask

# Load the demo images. The batch holds two instances with different image
# counts (3 and 2); Image.open is called in the same order as before.
images = [
    [
        Image.open("images/chinchilla.png"),
        Image.open("images/shiba.png"),
        Image.open("images/flamingo.png"),
    ],
    [
        Image.open("images/shiba.png"),
        Image.open("images/flamingo.png"),
    ],
]

# One interleaved prompt per instance; <imageN> refers to the N-th image of
# that instance.
prompt = [
    f'image 0 is <image0>{replace_token},image 1 is <image1>{replace_token},image 2 is <image2>{replace_token}. Question: <image0> is a chinchilla. They are mainly found in Chile.\n Question: <image1> is a shiba. They are very popular in Japan.\nQuestion: image 2 is',
    f'image 0 is <image0>{replace_token}, image 0 is a shiba. They are very popular in Japan.\n image 1 is <image1>{replace_token}, image 1 is a',
]

# Every instance is padded up to the largest image count in the batch.
max_image_length = max(len(instance_images) for instance_images in images)

# Tokenize all prompts together; padding=True aligns them to one length.
inputs = processor(text=prompt, return_tensors="pt", padding=True)

# Preprocess each instance's images separately, since their counts differ.
pixel_values = [
    processor(images=instance_images, return_tensors="pt")["pixel_values"]
    for instance_images in images
]

# Pad each instance to max_image_length and collect the matching masks.
padded_batch = [
    padd_images(instance_pixels, max_image_length)
    for instance_pixels in pixel_values
]
inputs["pixel_values"] = torch.stack(
    [padded for padded, _ in padded_batch]
).to(torch.bfloat16)
inputs["img_mask"] = torch.stack([mask for _, mask in padded_batch])
inputs = inputs.to("cuda:1")

# Deterministic decoding (do_sample=False). img_mask marks which image slots
# hold real images and which are padding.
generation_options = dict(
    do_sample=False,
    max_length=50,
    min_length=1,
    set_min_padding_size=False,
)
outputs = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    img_mask=inputs["img_mask"],
    **generation_options,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)
print(generated_text)

It works, thx~