AILab-CVC / SEED

Official implementation of SEED-LLaMA (ICLR 2024).

Home Page: https://ailab-cvc.github.io/seed

How to force model to generate image?

haochuan-li opened this issue

Hi! Great work.

I see there's a "force image generation" option in the Gradio demo.
How can I implement this in code? Can anyone enlighten me on this?

Thanks.

Sorry for the late reply. Forced image generation can be achieved by manually appending the BOI (begin-of-image) token to the input text. The relevant line in the code is:

input_text += BOI_TOKEN
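
As a minimal sketch of the reply above, forcing image generation amounts to a string concatenation before the prompt is tokenized. The value given to `BOI_TOKEN` here and the helper `build_prompt` are assumptions for illustration only; use the constant defined in the repository's own code.

```python
# Assumed placeholder; the real begin-of-image token string is defined in the SEED repo.
BOI_TOKEN = "<img>"


def build_prompt(input_text: str, force_image_generation: bool = False) -> str:
    """Return the prompt, optionally ending with the begin-of-image token."""
    if force_image_generation:
        # Ending the prompt with BOI means generation continues from inside an image span.
        input_text += BOI_TOKEN
    return input_text


prompt = build_prompt("Draw a cat wearing a hat.", force_image_generation=True)
print(prompt)
```

Keeping the flag as an explicit argument mirrors the checkbox in the demo: the same prompt-building path serves both the normal and the forced case.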

Thanks for the reply!

@sijeh, I have another question, related to the zero-shot retrieval evaluation. I cannot reproduce the Table 1 results in the SEED-LLaMA paper.

Here is my code for preparing the text embeddings and image embeddings for Flickr30k:

"""
Setting: Using Seed-LLaMA Tokenizer 2
"""
import hydra
import torch
from lavis.models import load_model
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

device = 'cuda'

tokenizer_cfg_path = 'configs/tokenizer/seed_llama_tokenizer_hf.yaml'
tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
seed_tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=device, load_diffusion=False)

"""Preparing Flickr Text Embedding, simply follow blip2 retrieval"""
blip2_model = load_model("blip2", "pretrain")
blip2_model.eval().to(device)

# `captions`, `TextDataset`, and `args` are not defined in this snippet.
text_emb = []
blip_text = blip2_model.tokenizer(captions, padding='max_length', truncation=True, max_length=32, return_tensors='pt')

blip_dataset = TextDataset(blip_text)
blip_dataloader = DataLoader(blip_dataset,
                             shuffle=False,
                             drop_last=False,
                             num_workers=8,
                             pin_memory=True,
                             batch_size=args.batch_size)

with torch.no_grad():
    for (input_ids, attention_mask) in tqdm(blip_dataloader, desc='text', unit='text'):
        # Use the hidden state at position 0 of the Q-Former text output as the text embedding
        qformer_output = blip2_model.Qformer.bert(
            input_ids.to(device),
            attention_mask=attention_mask.to(device),
            return_dict=True,
        ).last_hidden_state[:, 0, :]
        text_emb.append(qformer_output.detach().cpu())
text_emb = torch.concat(text_emb)  # Text Emb for Retrieval, shape=[5000, 768]


"""Preparing Flickr Image Embedding"""
# `imgs_gt` and `transform` are not defined in this snippet.
causal_code_pt = []
causal_emb_pt = []
for im in tqdm(imgs_gt, desc="tokenizing img", unit='img'):
    _, causal_code, causal_emb = seed_tokenizer.encode_image(image_torch=transform(im).to(device))
    # Keep only the last of the 32 query positions as the image embedding
    causal_code_pt.append(causal_code[0][-1].squeeze().detach().cpu())
    causal_emb_pt.append(causal_emb[0][-1].squeeze().detach().cpu())

causal_code_pt = torch.stack(causal_code_pt)  # Causal Code For Retrieval, shape=[1000, 768]
causal_emb_pt = torch.stack(causal_emb_pt)  # Causal Emb For Retrieval, shape=[1000, 768]
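
The text-embedding code above passes the tokenizer output to a `TextDataset` that the post does not show. A hypothetical minimal version, assuming it only needs to yield `(input_ids, attention_mask)` pairs, might look like the sketch below. `torch.utils.data.DataLoader` accepts any map-style object that implements `__len__` and `__getitem__`, so the sketch has no dependencies; the poster's actual class may differ.

```python
class TextDataset:
    """Hypothetical stand-in for the TextDataset that the post does not define."""

    def __init__(self, encoded):
        # `encoded` is a mapping holding "input_ids" and "attention_mask",
        # such as the output of a Hugging Face tokenizer call.
        self.input_ids = encoded["input_ids"]
        self.attention_mask = encoded["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        # One caption per item: its token ids and the matching attention mask.
        return self.input_ids[idx], self.attention_mask[idx]


# Toy data standing in for real tokenizer output.
toy = TextDataset({"input_ids": [[101, 7, 102], [101, 9, 102]],
                   "attention_mask": [[1, 1, 1], [1, 1, 1]]})
print(len(toy), toy[1])
```

When the tokenizer is called with `return_tensors='pt'`, each item is a pair of tensors, and the default collate function of `DataLoader` stacks them into batched `input_ids` and `attention_mask` tensors, which matches how the loop above unpacks each batch.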

"""
Details on how I obtain the causal code and causal emb:
I modified get_codebook_indices in models/seed_qformer/qformer_quantizer.py
"""

def get_codebook_indices(self, image):
    with torch.no_grad():
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        print("image embeds", image_embeds.shape) # [1,257,1408]
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        ) 
        # query_output.last_hidden_state: shape=[1,32,768]
        # query_output_down:              shape=[1,32,32]
        # query_output_up:                shape=[1,32,768]

        query_output_down = self.encode_task_layer(query_output.last_hidden_state)
        quant, loss_embed, embed_ind = self.quantize(query_output_down)
        embed_ind = embed_ind.reshape(quant.shape[0], -1)

        query_output_up = self.decode_task_layer(quant)
    # Returns: the code indices, the embedding reconstructed from the quantized codes
    # ("causal code"), and the Q-Former output before quantization ("causal emb")
    return embed_ind, query_output_up, query_output.last_hidden_state

"""Compute Similarity Matrix"""
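# L2-normalize the stacked image embeddings (causal_code_pt / causal_emb_pt from above)
causal_code_pt = causal_code_pt / causal_code_pt.norm(dim=-1, keepdim=True)
causal_emb_pt = causal_emb_pt / causal_emb_pt.norm(dim=-1, keepdim=True)

# Note: text_emb is not L2-normalized here, so these are scaled dot products, not cosine similarities
blip_causal_code_sim = text_emb @ causal_code_pt.T  # shape=[5000, 1000]
blip_causal_emb_sim = text_emb @ causal_emb_pt.T  # shape=[5000, 1000]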
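
For reference, retrieval metrics such as Recall@K are usually computed from a similarity matrix like the ones above. The sketch below is a dependency-free illustration of text-to-image Recall@K. It assumes the common Flickr30k test layout in which captions 5i to 5i+4 describe image i; that layout and the helper names are assumptions, not taken from the post or from the paper's evaluation code. It normalizes both sides, since cosine similarity requires unit-length text and image vectors.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def cosine_sim_matrix(text_embs, image_embs):
    """Rows index texts, columns index images."""
    texts = [l2_normalize(t) for t in text_embs]
    images = [l2_normalize(i) for i in image_embs]
    return [[sum(a * b for a, b in zip(t, i)) for i in images] for t in texts]


def text_to_image_recall(sim, k, captions_per_image=5):
    """Fraction of captions whose ground-truth image is among the top-k columns."""
    hits = 0
    for text_idx, row in enumerate(sim):
        gt_image = text_idx // captions_per_image  # assumed caption ordering
        top_k = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        hits += gt_image in top_k
    return hits / len(sim)


# Toy example: 2 images, 1 caption each, 2-dimensional embeddings.
image_embs = [[1.0, 0.0], [0.0, 1.0]]
text_embs = [[2.0, 0.1], [0.3, 5.0]]
sim = cosine_sim_matrix(text_embs, image_embs)
r1 = text_to_image_recall(sim, k=1, captions_per_image=1)
print(r1)
```

Scaling a text vector by a positive constant does not change how images rank for that text, so an unnormalized text side leaves text-to-image ranking unchanged. Image-to-text ranking compares scores across different texts, so there the text norms do affect the result. With the tensors from the post, the same ranking can be obtained with `torch.topk` on the similarity matrix.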

Results in paper

[image not preserved]

Reproduced Results

[image not preserved]

Question

[image not preserved]

I'm not sure whether this is the right way to obtain the text embeddings and image embeddings described in the SEED-LLaMA paper. Please correct me if I'm wrong.

Looking forward to your reply.

Thanks.