Rishit-dagli / Transformer-in-Transformer

An Implementation of Transformer in Transformer in TensorFlow for image classification, attention inside local patches

Home Page: https://pypi.org/project/tnt-tensorflow

Put together a TNT class

Rishit-dagli opened this issue

Verify shapes:

```python
import tensorflow as tf
from tnt import TNT  # provided by the tnt-tensorflow package

tnt = TNT(
    image_size=256,  # size of image
    patch_dim=512,  # dimension of patch token
    pixel_dim=24,  # dimension of pixel token
    patch_size=16,  # patch size
    pixel_size=4,  # pixel size
    depth=5,  # depth
    num_classes=1000,  # output number of classes
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1,  # feedforward dropout
)

# One random channels-first image: (batch, channels, height, width)
img = tf.random.uniform(shape=[1, 3, 256, 256])
print(tnt(img).shape)

# (1, 1000)
```
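When checking intermediate shapes, it may help to know how many tokens the configuration above should produce. The sketch below is a rough sanity check, not code from this repository: `tnt_token_counts` is a hypothetical helper, and it assumes the image is split into non-overlapping square patches and each patch into non-overlapping square pixel blocks, as described for Transformer in Transformer. The package's internals may differ (for example, an extra class token on the patch sequence).

```python
from typing import Tuple


def tnt_token_counts(image_size: int, patch_size: int, pixel_size: int) -> Tuple[int, int]:
    """Return (number of patches, pixel tokens per patch).

    Hypothetical helper: assumes non-overlapping square splits at both levels.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    if patch_size % pixel_size != 0:
        raise ValueError("patch_size must be divisible by pixel_size")

    # Patches form a grid of (image_size / patch_size) per side.
    num_patches = (image_size // patch_size) ** 2
    # Inside each patch, pixel blocks form a grid of (patch_size / pixel_size) per side.
    pixels_per_patch = (patch_size // pixel_size) ** 2
    return num_patches, pixels_per_patch


num_patches, pixels_per_patch = tnt_token_counts(image_size=256, patch_size=16, pixel_size=4)
print(num_patches, pixels_per_patch)  # 256 16
```

Under these assumptions, the inner transformer would see 16 tokens of width `pixel_dim` per patch, and the outer transformer would see 256 patch tokens of width `patch_dim`.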