zhunge / ConformerViT

ViT + Conformer = ( ͡❛ ͜ʖ ͡❛)👌

Conformer-ViT

This repository combines Vision Transformers (ViT) with Conformer blocks for image classification and for image-to-sequence (Image2Seq) tasks such as OCR and image captioning.
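The Conformer side of the design comes from speech recognition, where each block pairs self-attention with a convolution module built around a depthwise 1-D convolution over the token sequence. The pure-Python sketch below shows how a kernel size and a causal flag usually control that convolution's padding. In non-causal mode the window is centred on each token. In causal mode a token sees only itself and earlier tokens. It is an assumption that the `kernel_size` and `causal` arguments in the usage examples behave this way, and `depthwise_conv1d` is a hypothetical helper, not part of `conformer_vit`.

```python
from typing import List


def depthwise_conv1d(
    x: List[List[float]], kernels: List[List[float]], causal: bool = False
) -> List[List[float]]:
    """Apply one 1-D kernel per channel along the token (time) axis.

    x: T frames, each a list of C channel values.
    kernels: C kernels of the same odd length K (for example K = 17).
    causal=False pads (K - 1) // 2 zeros on each side, so the window is centred.
    causal=True pads K - 1 zeros on the left only, so output t sees frames <= t.
    The output always has the same length T as the input.
    """
    T, C = len(x), len(x[0])
    K = len(kernels[0])
    left = K - 1 if causal else (K - 1) // 2  # zeros padded before frame 0
    out = []
    for t in range(T):
        frame = []
        for c in range(C):  # depthwise: each channel uses only its own kernel
            acc = 0.0
            for k in range(K):
                src = t - left + k  # index into the unpadded input
                if 0 <= src < T:  # padded positions contribute zero
                    acc += kernels[c][k] * x[src][c]
            frame.append(acc)
        out.append(frame)
    return out
```

With a length-3 kernel of ones, `depthwise_conv1d([[1.0], [2.0], [3.0], [4.0]], [[1.0, 1.0, 1.0]])` returns `[[3.0], [6.0], [9.0], [7.0]]`, while passing `causal=True` gives `[[1.0], [3.0], [6.0], [9.0]]`.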

Usage - Image2Seq

from conformer_vit import ConformerViTForImage2Seq
import torch

model = ConformerViTForImage2Seq(
    image_size=256,
    patch_size=16,
    num_classes=150,
    dim=320,
    depth=12,
    heads=8,
    decoder_dim=640,
    output_seq_len=128,
    decoder_type="transformer",
    SOS_token=1,
    EOS_token=2,
    channels=1,
    dropout=0.1,
    emb_dropout=0.1,
    kernel_size=17,
    causal=False
)

inp = torch.randn(1, 1, 64, 256)
target_seq = torch.randint(0, 150, (1, 128))
pred = model(inp, target_seq=target_seq, teacher_forcing_ratio=0.5)
print(pred.shape) # (1, 128, 150)
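The `teacher_forcing_ratio` argument in the Image2Seq example suggests the decoder is trained with scheduled teacher forcing. At each step, the token fed back into the decoder is either the ground-truth token from `target_seq` or the model's own previous prediction. The plain-Python sketch below illustrates that loop under this assumption; `decode_with_teacher_forcing` and the `step` callback are hypothetical stand-ins for the repository's decoder, which works on logits rather than on single tokens.

```python
import random
from typing import Callable, List


def decode_with_teacher_forcing(
    step: Callable[[int, int], int],
    target_seq: List[int],
    sos_token: int,
    teacher_forcing_ratio: float,
    rng: random.Random,
) -> List[int]:
    """Autoregressive decoding where each next input is chosen at random.

    step(position, prev_token) is a hypothetical single decoder step that
    returns the token the model predicts at `position`.
    """
    preds = []
    prev = sos_token  # decoding starts from the start-of-sequence token
    for pos, gold in enumerate(target_seq):
        pred = step(pos, prev)
        preds.append(pred)
        # With probability teacher_forcing_ratio feed the ground-truth token
        # to the next step; otherwise feed back the model's own prediction.
        prev = gold if rng.random() < teacher_forcing_ratio else pred
    return preds
```

A ratio of `1.0` always feeds the ground truth, `0.0` always feeds back the model's prediction (as happens at inference time), and `0.5` picks between the two at random on every step.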

Usage - Image Classification

from conformer_vit import ConformerViTForClassification
import torch

model = ConformerViTForClassification(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=144,
    depth=12,
    heads=16,
    dropout=0.1,
    emb_dropout=0.1
)

img = torch.randn(1, 3, 256, 256)

preds = model(img)  # (1, 1000)
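Assuming the ConformerViT front end follows the standard ViT patch embedding (non-overlapping patches, each flattened and then linearly projected to `dim`), the number of tokens and the length of each flattened patch follow from simple arithmetic. `pair`, `num_patches` and `patch_vector_len` are hypothetical helpers, not part of `conformer_vit`. This README also does not say whether a class token is prepended, or how the 64 x 256 Image2Seq input relates to `image_size=256`, so the tuple form below is only for illustration.

```python
from typing import Tuple, Union

Size = Union[int, Tuple[int, int]]


def pair(x: Size) -> Tuple[int, int]:
    """Expand an int to a (height, width) pair; tuples pass through."""
    return x if isinstance(x, tuple) else (x, x)


def num_patches(image_size: Size, patch_size: Size) -> int:
    """Number of non-overlapping patches (tokens) cut from one image."""
    h, w = pair(image_size)
    ph, pw = pair(patch_size)
    if h % ph or w % pw:
        raise ValueError("image height and width must be divisible by the patch size")
    return (h // ph) * (w // pw)


def patch_vector_len(channels: int, patch_size: Size) -> int:
    """Length of one flattened patch before it is projected to `dim`."""
    ph, pw = pair(patch_size)
    return channels * ph * pw


# Classification example: 256x256 RGB image, 32x32 patches.
print(num_patches(256, 32), patch_vector_len(3, 32))  # 64 3072
# Image2Seq example: 64x256 grayscale input, 16x16 patches.
print(num_patches((64, 256), 16), patch_vector_len(1, 16))  # 64 256
```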

Acknowledgement

The decoder code is borrowed from here.

Languages

Python: 100.0%