Batch inference support?
OuYangg opened this issue · comments
I measured the inference speed of the MMICL model on a V100 (32 GB) and it takes about 1.5 seconds per image. Does the MMICL model support batch inference?
Absolutely. It supports both batch training and batch inference. You only need to pad the images of each instance to the same count (the largest number of images of any instance in the batch) and then pass the corresponding image mask so that the padded image slots are masked out.
Here is the code for batch inference:
from torch.nn.functional import pad
def padd_images(image, max_length):
    """Pad a stack of images along dim 0 to ``max_length`` and build its mask.

    Args:
        image: Tensor (or array-like convertible by ``torch.tensor``) whose
            first dimension indexes the images of one sample. Any rank >= 1
            is accepted; the trailing dimensions are left untouched.
        max_length: Target size of the first dimension.

    Returns:
        A ``(image, mask)`` tuple. ``image`` is a new tensor zero-padded at
        the end of dim 0; ``mask`` is a bool tensor of shape ``(max_length,)``
        that is ``True`` for the original entries and ``False`` for padding.
    """
    if isinstance(image, torch.Tensor):
        # torch.tensor(<Tensor>) emits a copy-construct UserWarning;
        # detach().clone() is the equivalent PyTorch recommends.
        image = image.detach().clone()
    else:
        image = torch.tensor(image)

    num_images = image.shape[0]
    mask = torch.zeros(max_length).bool()
    mask[:num_images] = True

    pad_len = max_length - num_images
    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # the final pair addresses dim 0. Building the tuple from the rank yields
    # (0, 0, 0, 0, 0, 0, 0, pad_len) for the usual 4-D input.
    pad_spec = (0, 0) * (image.dim() - 1) + (0, pad_len)
    image = pad(image, pad_spec)  # padding behind the first dim
    return image, mask
# Demo pictures: the first sample uses three images, the second one two.
image = Image.open("images/chinchilla.png")
image1 = Image.open("images/shiba.png")
image2 = Image.open("images/flamingo.png")
image4 = Image.open("images/shiba.png")
image5 = Image.open("images/flamingo.png")

# One inner list of images per sample in the batch.
images = [[image, image1, image2], [image4, image5]]

# One prompt per sample, in the same order as `images`.
prompt = [
    f'image 0 is <image0>{replace_token},image 1 is <image1>{replace_token},image 2 is <image2>{replace_token}. Question: <image0> is a chinchilla. They are mainly found in Chile.\n Question: <image1> is a shiba. They are very popular in Japan.\nQuestion: image 2 is',
    f'image 0 is <image0>{replace_token}, image 0 is a shiba. They are very popular in Japan.\n image 1 is <image1>{replace_token}, image 1 is a',
]

# Largest number of images in any sample; every sample is padded up to it.
max_image_length = max(len(sample) for sample in images)

# Tokenize the prompts as a padded batch, then preprocess each sample's
# images separately because the samples hold different numbers of images.
inputs = processor(text=prompt, return_tensors="pt", padding=True)
pixel_values = [
    processor(images=sample, return_tensors="pt")['pixel_values']
    for sample in images
]

# Pad every sample's image stack to the common count and keep its mask.
padded_pairs = [padd_images(px, max_image_length) for px in pixel_values]
image_list = [pair[0] for pair in padded_pairs]
mask_list = [pair[1] for pair in padded_pairs]

inputs['pixel_values'] = torch.stack(image_list).to(torch.bfloat16)
inputs['img_mask'] = torch.stack(mask_list)
inputs = inputs.to('cuda:1')

generation_kwargs = dict(
    pixel_values=inputs['pixel_values'],
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    img_mask=inputs['img_mask'],
    do_sample=False,
    max_length=50,
    min_length=1,
    set_min_padding_size=False,
)
outputs = model.generate(**generation_kwargs)

generated_text = processor.batch_decode(outputs, skip_special_tokens=True)
print(generated_text)
Absolutely. It supports both batch training and batch inference. You only need to pad the images of each instance to the same count (the largest number of images of any instance in the batch) and then pass the corresponding image mask so that the padded image slots are masked out.
Here is the code for batch inference:
from torch.nn.functional import pad


def padd_images(image, max_length):
    """Pad a stack of images along dim 0 to ``max_length`` and build its mask.

    Args:
        image: Tensor (or array-like convertible by ``torch.tensor``) whose
            first dimension indexes the images of one sample.
        max_length: Target size of the first dimension.

    Returns:
        A ``(image, mask)`` tuple. ``image`` is zero-padded at the end of
        dim 0; ``mask`` is a bool tensor of shape ``(max_length,)`` that is
        ``True`` for the original entries and ``False`` for padding.
    """
    if isinstance(image, torch.Tensor):
        # torch.tensor(<Tensor>) emits a copy-construct UserWarning;
        # detach().clone() is the equivalent PyTorch recommends.
        image = image.detach().clone()
    else:
        image = torch.tensor(image)

    num_images = image.shape[0]
    mask = torch.zeros(max_length).bool()
    mask[:num_images] = True

    pad_len = max_length - num_images
    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # the final pair addresses dim 0.
    pad_spec = (0, 0) * (image.dim() - 1) + (0, pad_len)
    image = pad(image, pad_spec)  # padding behind the first dim
    return image, mask


image = Image.open("images/chinchilla.png")
image1 = Image.open("images/shiba.png")
image2 = Image.open("images/flamingo.png")
image4 = Image.open("images/shiba.png")
image5 = Image.open("images/flamingo.png")

# One inner list of images per sample in the batch.
images = [
    [image, image1, image2],
    [image4, image5],
]

# One prompt per sample, in the same order as `images`.
prompt = [
    f'image 0 is <image0>{replace_token},image 1 is <image1>{replace_token},image 2 is <image2>{replace_token}. Question: <image0> is a chinchilla. They are mainly found in Chile.\n Question: <image1> is a shiba. They are very popular in Japan.\nQuestion: image 2 is',
    f'image 0 is <image0>{replace_token}, image 0 is a shiba. They are very popular in Japan.\n image 1 is <image1>{replace_token}, image 1 is a',
]

# Largest number of images in any sample; every sample is padded up to it.
max_image_length = max(len(f) for f in images)

inputs = processor(text=prompt, return_tensors="pt", padding=True)
# Each sample is preprocessed on its own because the image counts differ.
pixel_values = [
    processor(images=img, return_tensors="pt")['pixel_values']
    for img in images
]

image_list = []
mask_list = []
for img in pixel_values:
    image, img_mask = padd_images(img, max_image_length)
    image_list.append(image)
    mask_list.append(img_mask)

inputs['pixel_values'] = torch.stack(image_list).to(torch.bfloat16)
inputs['img_mask'] = torch.stack(mask_list)
inputs = inputs.to('cuda:1')

outputs = model.generate(
    pixel_values=inputs['pixel_values'],
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    img_mask=inputs['img_mask'],
    do_sample=False,
    max_length=50,
    min_length=1,
    set_min_padding_size=False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)
print(generated_text)
It works, thx~