CCT-keras: Compact Transformers implemented in Keras

Implementation of Compact Transformers from the paper [Escaping the Big Data Paradigm with Compact Transformers](https://arxiv.org/abs/2104.05704).

The official PyTorch implementation can be found here: https://github.com/SHI-Labs/Compact-Transformers

[Figure: SHI-Labs comparison]

Overview

The Compact Convolutional Transformer (CCT) introduces three main changes to ViT:

  • Convolutional Tokenizer, instead of ViT's direct image patching
  • Sequence Pooling instead of the Class Token (see the sketch below)
  • Learnable Positional Embedding instead of Sinusoidal Embedding

CCT naturally inherits other components of ViT, such as:

  • Multi-Head Self Attention
  • Feed Forward Network (MLP Block)
  • Dropouts and Stochastic Depth
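
Of these, Sequence Pooling is the most distinctive: a single dense unit scores every output token, and a softmax-weighted sum of the tokens replaces ViT's class token. Below is a minimal TensorFlow sketch, written to mirror the `dense` → `tf.linalg.matmul` → `flatten` layers visible in the model summary further down (the function name is illustrative, not from the package):

```python
import tensorflow as tf

def sequence_pooling(x):
    """Attention-weighted pooling over the token sequence.

    x: (batch, seq_len, d_model) tokens from the final transformer block.
    Returns a (batch, d_model) representation in place of a class token.
    """
    scores = tf.keras.layers.Dense(1)(x)              # (batch, seq_len, 1)
    weights = tf.nn.softmax(scores, axis=1)           # attention over tokens
    pooled = tf.matmul(weights, x, transpose_a=True)  # (batch, 1, d_model)
    return tf.keras.layers.Flatten()(pooled)          # (batch, d_model)
```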

Usage

```
!pip install git+https://github.com/johnypark/CCT-keras
```

```python
from CCT_keras import CCT

model = CCT(num_classes=1000, input_shape=(224, 224, 3))
```

The default CCT() is CCT_14_7x2 from the paper, the variant the authors trained on ImageNet from scratch.
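
Training follows the usual Keras workflow. A minimal sketch with placeholder hyperparameters (this is not the recipe from the paper, and `train_dataset`/`val_dataset` are assumed to exist):

```python
import tensorflow as tf

# Placeholder settings, not the authors' ImageNet recipe.
# Set from_logits to match whether the classification head applies a softmax.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

# model.fit(train_dataset, validation_data=val_dataset, epochs=300)
```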


```python
model.summary()
```

```
...
 layer_normalization_26 (LayerN  (None, 196, 384)    768         ['add_25[0][0]']                 
 ormalization)                                                                                    
                                                                                                  
 multi_head_self_attention_13 (  (None, None, 384)   591360      ['layer_normalization_26[0][0]'] 
 MultiHeadSelfAttention)                                                                          
                                                                                                  
 drop_path_26 (DropPath)        (None, None, 384)    0           ['multi_head_self_attention_13[0]
                                                                 [0]']                            
                                                                                                  
 add_26 (Add)                   (None, 196, 384)     0           ['add_25[0][0]',                 
                                                                  'drop_path_26[0][0]']           
                                                                                                  
 layer_normalization_27 (LayerN  (None, 196, 384)    768         ['add_26[0][0]']                 
 ormalization)                                                                                    
                                                                                                  
 feed_forward_network_13 (FeedF  (None, 196, 384)    886272      ['layer_normalization_27[0][0]'] 
 orwardNetwork)                                                                                   
                                                                                                  
 drop_path_27 (DropPath)        (None, 196, 384)     0           ['feed_forward_network_13[0][0]']
                                                                                                  
 add_27 (Add)                   (None, 196, 384)     0           ['add_26[0][0]',                 
                                                                  'drop_path_27[0][0]']           
                                                                                                  
 layer_normalization_28 (LayerN  (None, 196, 384)    768         ['add_27[0][0]']                 
 ormalization)                                                                                    
                                                                                                  
 dense (Dense)                  (None, 196, 1)       385         ['layer_normalization_28[0][0]'] 
                                                                                                  
 tf.linalg.matmul (TFOpLambda)  (None, 1, 384)       0           ['dense[0][0]',                  
                                                                  'layer_normalization_28[0][0]'] 
                                                                                                  
 flatten (Flatten)              (None, 384)          0           ['tf.linalg.matmul[0][0]']       
                                                                                                  
 dropout_1 (Dropout)            (None, 384)          0           ['flatten[0][0]']                
                                                                                                  
 dense_1 (Dense)                (None, 1000)         385000      ['dropout_1[0][0]']              
                                                                                                  
==================================================================================================
Total params: 24,735,401
Trainable params: 24,735,401
Non-trainable params: 0
__________________________________________________________________________________________________
```

Access Model Weights

```python
# Map each weight's name to its index, dtype, and shape for quick lookup.
model_weights_dict = {w.name: (idx, w.dtype, w.shape) for idx, w in enumerate(model.weights)}

# Example: indices of all Dense-layer weights.
names_dense = [name for name in model_weights_dict if 'dense' in name]
idx_dense = [model_weights_dict[name][0] for name in names_dense]
```


```
>>> model_weights_dict
{'conv2d/kernel:0': (0, tf.float32, TensorShape([7, 7, 3, 192])),
 'conv2d_1/kernel:0': (1, tf.float32, TensorShape([7, 7, 192, 384])),
 'layer_normalization/gamma:0': (2, tf.float32, TensorShape([384])),
 'layer_normalization/beta:0': (3, tf.float32, TensorShape([384])),
 'multi_head_self_attention/dense_query/kernel:0': (4,
  tf.float32,
  TensorShape([384, 384])),
 'multi_head_self_attention/dense_query/bias:0': (5,
  tf.float32,
  TensorShape([384])),
 'multi_head_self_attention/dense_key/kernel:0': (6,
  tf.float32,
  TensorShape([384, 384])),
 'multi_head_self_attention/dense_key/bias:0': (7,
  tf.float32,
  TensorShape([384])),
 'multi_head_self_attention/dense_value/kernel:0': (8,
  tf.float32,
  TensorShape([384, 384])),
 'multi_head_self_attention/dense_value/bias:0': (9,
  tf.float32,
  TensorShape([384])),
 'multi_head_self_attention/dense_out/kernel:0': (10,
  tf.float32,
  TensorShape([384, 384])),
 'multi_head_self_attention/dense_out/bias:0': (11,
  tf.float32,
  TensorShape([384])),
 'layer_normalization_1/gamma:0': (12, tf.float32, TensorShape([384])),
 'layer_normalization_1/beta:0': (13, tf.float32, TensorShape([384])),
 'feed_forward_network/dense_hidden/kernel:0': (14,
  tf.float32,
  TensorShape([384, 1152])),
 'feed_forward_network/dense_hidden/bias:0': (15,
  tf.float32,
  TensorShape([1152])),
 'feed_forward_network/dense_out/kernel:0': (16,
  tf.float32,
  TensorShape([1152, 384])),
 'feed_forward_network/dense_out/bias:0': (17, tf.float32, TensorShape([384])),
 'layer_normalization_2/gamma:0': (18, tf.float32, TensorShape([384])),
 'layer_normalization_2/beta:0': (19, tf.float32, TensorShape([384])),
 ...
```
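
Individual weights can then be read or overwritten by index, which is the hook for porting checkpoints. A minimal sketch (the zero-filled kernel is a stand-in, not actual ported values):

```python
import numpy as np

# Look up the first tokenizer convolution by name and overwrite it in place.
idx, dtype, shape = model_weights_dict['conv2d/kernel:0']
ported_kernel = np.zeros(shape.as_list(), dtype=dtype.as_numpy_dtype)  # stand-in
model.weights[idx].assign(ported_kernel)
```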

Results and Pre-trained Weights

Results and weights are adopted directly from the official PyTorch implementation (https://github.com/SHI-Labs/Compact-Transformers). I plan to gradually port the PyTorch weights to TensorFlow and keep things posted here. Type reads as L/PxC, where L is the number of transformer layers, P is the patch/convolution size, and C (CCT only) is the number of convolutional layers; for example, CCT-14/7x2 stacks 14 transformer layers on a tokenizer with two 7x7 convolutions.

CIFAR-10 and CIFAR-100

| Model | Pre-training | Epochs | PE | Source | CIFAR-10 | CIFAR-100 |
|-----------|------|------|------------|------------------|--------|--------|
| CCT-7/3x1 | None | 300  | Learnable  | Official PyTorch | 96.53% | 80.92% |
|           |      |      |            | CCT-keras        | TBD    | TBD    |
| CCT-7/3x1 | None | 1500 | Sinusoidal | Official PyTorch | 97.48% | 82.72% |
|           |      |      |            | CCT-keras        | TBD    | TBD    |
| CCT-7/3x1 | None | 5000 | Sinusoidal | Official PyTorch | 98.00% | 82.87% |
|           |      |      |            | CCT-keras        | TBD    | TBD    |

Flowers-102

| Model | Pre-training | PE | Image Size | Source | Accuracy |
|------------|-------------|------------|---------|------------------|--------|
| CCT-7/7x2  | None        | Sinusoidal | 224x224 | Official PyTorch | 97.19% |
|            |             |            |         | CCT-keras        | TBD    |
| CCT-14/7x2 | ImageNet-1k | Learnable  | 384x384 | Official PyTorch | 99.76% |
|            |             |            |         | CCT-keras        | TBD    |

ImageNet

| Model | Type | Resolution | Epochs | # Params | MACs | Source | Top-1 Accuracy |
|-----|--------|-----|----------|--------|--------|------------------|--------|
| ViT | 12/16  | 384 | 300      | 86.8M  | 17.6G  | Official PyTorch | 77.91% |
| CCT | 14/7x2 | 224 | 310      | 22.36M | 5.11G  | Official PyTorch | 80.67% |
|     |        |     |          |        |        | CCT-keras        | TBD    |
| CCT | 14/7x2 | 384 | 310 + 30 | 22.51M | 15.02G | Official PyTorch | 82.71% |
|     |        |     |          |        |        | CCT-keras        | TBD    |

License: Apache License 2.0

