Put together a TNT class
Rishit Dagli (Rishit-dagli) opened this issue.
Verify the output shapes, for example:
```python
import tensorflow as tf

# TNT is the class proposed in this issue; import it from wherever it is defined.
tnt = TNT(
    image_size=256,    # size of image
    patch_dim=512,     # dimension of patch token
    pixel_dim=24,      # dimension of pixel token
    patch_size=16,     # patch size
    pixel_size=4,      # pixel size
    depth=5,           # number of TNT blocks
    num_classes=1000,  # number of output classes
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1,    # feedforward dropout
)
img = tf.random.uniform(shape=[1, 3, 256, 256])
print(tnt(img).shape)
# (1, 1000)
```
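The intermediate shapes can also be worked out by hand before running the model. The plain-Python sketch below uses a hypothetical helper, `expected_tnt_shapes`, which is not part of any TNT package. It assumes the standard TNT layout: pixel tokens are grouped per patch for the inner transformer, and the outer transformer sees the patch tokens plus one prepended class token.

```python
def expected_tnt_shapes(batch, image_size, patch_size, pixel_size,
                        patch_dim, pixel_dim, num_classes):
    """Return the tensor shapes a TNT forward pass is expected to produce.

    Hypothetical helper (illustration only). Assumes square images and a
    single class token prepended to the patch sequence.
    """
    # The image must tile exactly into patches, and each patch into pixels.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    if patch_size % pixel_size != 0:
        raise ValueError("patch_size must be divisible by pixel_size")

    num_patches = (image_size // patch_size) ** 2       # outer sequence length
    pixels_per_patch = (patch_size // pixel_size) ** 2  # inner sequence length

    return {
        # Inner transformer: every patch is its own sequence of pixel tokens.
        "pixel_tokens": (batch * num_patches, pixels_per_patch, pixel_dim),
        # Outer transformer: patch tokens plus one class token.
        "patch_tokens": (batch, num_patches + 1, patch_dim),
        # Classification head output.
        "logits": (batch, num_classes),
    }


shapes = expected_tnt_shapes(batch=1, image_size=256, patch_size=16, pixel_size=4,
                             patch_dim=512, pixel_dim=24, num_classes=1000)
for name, shape in shapes.items():
    print(name, shape)
# pixel_tokens (256, 16, 24)
# patch_tokens (1, 257, 512)
# logits (1, 1000)
```

With the configuration above, 256 / 16 = 16 gives 16 × 16 = 256 patches, and 16 / 4 = 4 gives 4 × 4 = 16 pixel tokens per patch. Only the final `(1, 1000)` shape is stated in this issue; the intermediate shapes follow from the assumed layout. `depth` and the dropout rates do not affect any of these shapes.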