lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

Generation for PaLI?

BurgerAndreas opened this issue.

How would one generate an action (output text) using PaLI?

PaLI example from the README:

import torch
from x_transformers import ViTransformerWrapper, XTransformer, Encoder

# PaLI is composed of
# 1. vision transformer (ViTransformerWrapper) +
# 2. encoder-decoder transformer (XTransformer)

vit = ViTransformerWrapper(
    image_size = 256,
    patch_size = 32,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

pali = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024
)

# training data

img = torch.randn(1, 3, 256, 256)               # images
prompt = torch.randint(0, 256, (1, 1024))       # prompt
prompt_mask = torch.ones(1, 1024).bool()        # prompt text mask
output_text = torch.randint(0, 256, (1, 1024))  # target output text

# train

img_embeds = vit(
    img,
    return_embeddings = True
)

loss = pali(
    prompt,
    output_text,
    mask = prompt_mask,
    src_prepend_embeds = img_embeds             # prepends the image embeddings to the encoder's text embeddings before attention
)

loss.backward()
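As a quick shape check on the example above: assuming ViTransformerWrapper emits exactly one embedding per patch (no extra class token; check `img_embeds.shape[1]` in practice), a 256x256 image with 32x32 patches yields 64 image embeddings. These are placed in front of the 1024 prompt tokens, so the encoder output is longer than `prompt_mask`. The ViT `dim` also has to equal the encoder `dim` (both 512 here) for the concatenation to work. `num_patch_tokens` is an illustrative helper, not part of x-transformers.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    # square image split into non-overlapping square patches (assumes no class token)
    assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
    return (image_size // patch_size) ** 2


num_img = num_patch_tokens(256, 32)   # (256 // 32) ** 2 = 64 image embeddings
prompt_len = 1024                     # prompt length used in the example
encoder_len = num_img + prompt_len    # positions the encoder output covers
print(num_img, encoder_len)           # prints: 64 1088
```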

Desired behaviour

with torch.no_grad():
    vit.eval()
    pali.eval()

    img_embeds = vit(
        img,
        return_embeddings = True
    )
    
    # How to do this? XTransformer.generate() does not accept src_prepend_embeds,
    # so the image embeddings cannot be fed to the encoder. Desired call:
    output_text = pali.generate(
        img_embeds,
        prompt,
        mask = prompt_mask,
    )

Idea?

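img_embeds = vit(img, return_embeddings = True)   # (batch, num_patches, dim)

# from XTransformer.forward(): encode the prompt with the image embeddings prepended
enc = pali.encoder(prompt, mask = prompt_mask, prepend_embeds = img_embeds, return_embeddings = True)

# enc has img_embeds.shape[1] extra positions in front, so the context mask needs matching True entries
img_mask = torch.ones(img_embeds.shape[:2], dtype = torch.bool, device = prompt_mask.device)
context_mask = torch.cat((img_mask, prompt_mask), dim = -1)

# from XTransformer.generate(); seq_out_start and seq_len are the same arguments that method takes
output_text = pali.decoder.generate(seq_out_start, seq_len, context = enc, context_mask = context_mask)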
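A minimal sketch of this idea wrapped into a helper, assuming the x-transformers calls used in the snippet above: the encoder accepts `prepend_embeds` and `return_embeddings`, and `decoder.generate` forwards `context` / `context_mask` to the cross-attention layers. `generate_with_image` and `pad_context_mask` are hypothetical names, not part of the library. The key detail is that the decoder's `context_mask` must also cover the prepended image positions, mirroring how `XTransformer.forward()` pads the mask when `src_prepend_embeds` is given.

```python
import torch


def pad_context_mask(prompt_mask: torch.Tensor, num_prepended: int) -> torch.Tensor:
    """Prepend `num_prepended` True entries to a (batch, seq) boolean mask."""
    batch = prompt_mask.shape[0]
    prefix = torch.ones(batch, num_prepended, dtype=torch.bool, device=prompt_mask.device)
    return torch.cat((prefix, prompt_mask), dim=-1)


@torch.no_grad()
def generate_with_image(vit, pali, img, prompt, prompt_mask, seq_out_start, seq_len, **kwargs):
    """Hypothetical helper: XTransformer.generate() plus image embeddings prepended to the encoder input."""
    vit.eval()
    pali.eval()
    img_embeds = vit(img, return_embeddings=True)  # (batch, num_patches, dim)
    enc = pali.encoder(prompt, mask=prompt_mask, prepend_embeds=img_embeds, return_embeddings=True)
    # image positions are always attendable, so they get True in the context mask
    context_mask = pad_context_mask(prompt_mask, img_embeds.shape[1])
    return pali.decoder.generate(seq_out_start, seq_len, context=enc, context_mask=context_mask, **kwargs)
```

With the README models, a call would look like `generate_with_image(vit, pali, img, prompt, prompt_mask, seq_out_start, seq_len)`, where `seq_out_start` holds the decoder start token(s) and `seq_len` is the number of tokens to generate. Extra keyword arguments (e.g. sampling settings) are passed through to `decoder.generate`.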