leondgarse / Keras_insightface

Insightface Keras implementation

VPL - Variational Prototype Learning for deep face recognition

abdikaiym01 opened this issue.

Do you plan to implement their (insightface) latest work? It seems it does not work as well as they claimed in their paper: deepinsight/insightface#1801

Is it the current SOTA for the face recognition task?

I think an implementation is possible. I will try it and check whether it performs better.

VPL mode is added. It can be enabled by tt = train.Train(..., use_vpl=True). I think it should match the official implementation. Here are my test results using a basic EfficientNetV2B0 + AdamW:

import os
from tensorflow import keras
import tensorflow_addons as tfa
import losses, train, models  # modules from the Keras_insightface repo
keras.mixed_precision.set_global_policy("mixed_float16")

data_basic_path = '/datasets/ms1m-retinaface-t1'
data_path = data_basic_path + '_112x112_folders'
eval_paths = [os.path.join(data_basic_path, ii) for ii in ['lfw.bin', 'cfp_fp.bin', 'agedb_30.bin']]

# EfficientNetV2B0 backbone without a classifier head (num_classes=0)
from keras_cv_attention_models import efficientnet
basic_model = efficientnet.EfficientNetV2B0(input_shape=(112, 112, 3), num_classes=0)
# Wrap the backbone with a 512-d 'GDC' embedding output layer
basic_model = models.buildin_models(basic_model, dropout=0, emb_shape=512, output_layer='GDC', bn_epsilon=1e-4, bn_momentum=0.9, scale=True, use_bias=False)

tt = train.Train(data_path, eval_paths=eval_paths,
    save_path='TT_efv2_b0_swish_GDC_arc_emb512_dr0_adamw_5e4_bs512_ms1mv3_randaug_cos16_batch_float16_vpl.h5',
    basic_model=basic_model, model=None, lr_base=0.01, lr_decay=0.5, lr_decay_steps=16, lr_min=1e-6, lr_warmup_steps=3,
    batch_size=512, random_status=100, eval_freq=4000, output_weight_decay=1, use_vpl=True)

# AdamW with weight decay 5e-4, excluding BatchNorm gamma / beta from weight decay
optimizer = tfa.optimizers.AdamW(learning_rate=1e-2, weight_decay=5e-4, exclude_from_weight_decay=["/gamma", "/beta"])

# ArcFace loss with the scale increased in stages: 16, then 32, then 64
sch = [
    {"loss": losses.ArcfaceLoss(scale=16), "epoch": 4, "optimizer": optimizer},
    {"loss": losses.ArcfaceLoss(scale=32), "epoch": 3},
    {"loss": losses.ArcfaceLoss(scale=64), "epoch": 46},
]
tt.train(sch, 0)
exit()

Plot: [image not preserved]
Results

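| VPL   | lfw      | cfp_fp   | agedb_30 | IJBB 1e-4 | IJBC 1e-4 |
| ----- | -------- | -------- | -------- | --------- | --------- |
| False | 0.997667 | 0.979429 | 0.978333 | 0.941188  | 0.955719  |
| True  | 0.997667 | 0.979571 | 0.978500 | 0.938559  | 0.955054  |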

The results in the VPL paper were very different; in practice it turns out otherwise. What do you think about it? Maybe your implementation is a little different from theirs, though I'm not sure about that.

Yes, I have compared them several times. The main parts seem to be:

  • insightface/prepare_queue_lambda
      def prepare_queue_lambda(self, label, iters):
          self.queue_lambda[:] = 0.0
          if iters>self.cfg['start_iters']:
              allowed_delta = self.cfg['allowed_delta']
              if self.vpl_mode==0:
                  past_iters = iters - self.queue_iters
                  idx = torch.where(past_iters <= allowed_delta)[0]
                  self.queue_lambda[idx] = self.cfg['lambda']
    Here it is in models.py#L247:
    queue_lambda = tf.cond(
        self.iters > self.start_iters,
        lambda: tf.where(self.iters - self.queue_iters <= self.allowed_delta, self.vpl_lambda, 0.0),  # prepare_queue_lambda
        lambda: self.zero_queue_lambda,
    )
  • insightface vpl_mode>=0
    ...
    _lambda = self.queue_lambda.view(self.num_local, 1)
    injected_weight = norm_weight*(1.0-_lambda) + self.queue*_lambda
    injected_norm_weight = normalize(injected_weight)
    ...
    Here it is in models.py#L254:
    norm_w = K.l2_normalize(self.w, axis=0)
    injected_weight = norm_w * (1 - queue_lambda) + tf.transpose(self.queue_features) * queue_lambda
    injected_norm_weight = K.l2_normalize(injected_weight, axis=0)
  • insightface/set_queue
    def set_queue(self, total_features, total_label, index_positive, iters):
        local_label = total_label[index_positive]
        sel_features = normalize(total_features[index_positive,:])
        self.queue[local_label,:] = sel_features
        self.queue_iters[local_label] = iters
    Here in myCallbacks/VPLUpdateQueue
    class VPLUpdateQueue(keras.callbacks.Callback):
        def __init__(self):
            super().__init__()
    
        def on_batch_end(self, batch, logs=None):
            batch_labels_back_up = self.model.loss[0].batch_labels_back_up
            update_label_pos = tf.expand_dims(batch_labels_back_up, 1)
            vpl_norm_dense_layer = self.model.layers[-1]
    
            updated_queue = tf.tensor_scatter_nd_update(vpl_norm_dense_layer.queue_features, update_label_pos, vpl_norm_dense_layer.norm_features)
            vpl_norm_dense_layer.queue_features.assign(updated_queue)
    
            iters = tf.repeat(vpl_norm_dense_layer.iters, tf.shape(batch_labels_back_up)[0])
            updated_queue_iters = tf.tensor_scatter_nd_update(vpl_norm_dense_layer.queue_iters, update_label_pos, iters)
            vpl_norm_dense_layer.queue_iters.assign(updated_queue_iters)
    As this update needs the true labels, it has to be done outside the model. I think the logic matches the official one, unless I'm missing something...
  • In the official example_ms1mv3.py config, they use 'start_iters': 8000, 'allowed_delta': 200 with batch_size = 128. I'm using the same values with batch_size=512, and may try start_iters=8000 / 4, allowed_delta=200 / 4 for batch_size=512 later.
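Putting the three snippets above together, here is a minimal NumPy sketch of the VPL queue logic as it reads from them: a per-class lambda that is non-zero only after `start_iters` and only for classes refreshed within `allowed_delta` iterations, the injection of queued features into the normalized class weights, and the queue update from the current batch. It is a hypothetical restatement for illustration, not the repo's Keras layer or the official PyTorch code; the class name `VPLQueueSketch`, the helper `l2_normalize`, and the default `vpl_lambda=0.15` are assumptions.

```python
import numpy as np


def l2_normalize(x, axis, eps=1e-12):
    """L2-normalize `x` along `axis` (NumPy stand-in for K.l2_normalize / normalize)."""
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)


class VPLQueueSketch:
    """Hypothetical NumPy restatement of the VPL queue logic quoted above.

    Shapes follow the Keras snippets: class weights `w` are [emb_dim, num_classes],
    the queue keeps one normalized feature per class, [num_classes, emb_dim].
    """

    def __init__(self, w, start_iters=8000, allowed_delta=200, vpl_lambda=0.15):
        self.w = w
        emb_dim, num_classes = w.shape
        self.queue_features = np.zeros((num_classes, emb_dim), dtype=w.dtype)
        self.queue_iters = np.zeros(num_classes, dtype=np.int64)  # iteration of last refresh
        self.start_iters = start_iters
        self.allowed_delta = allowed_delta
        self.vpl_lambda = vpl_lambda  # assumed value, not taken from this thread

    def queue_lambda(self, iters):
        """prepare_queue_lambda: lambda only for classes refreshed within allowed_delta iters."""
        if iters <= self.start_iters:
            return np.zeros(self.queue_iters.shape, dtype=self.w.dtype)
        recent = (iters - self.queue_iters) <= self.allowed_delta
        return np.where(recent, self.vpl_lambda, 0.0).astype(self.w.dtype)

    def injected_norm_weight(self, iters):
        """Mix queued (prototype) features into the normalized class weights."""
        norm_w = l2_normalize(self.w, axis=0)
        lam = self.queue_lambda(iters)  # [num_classes], broadcasts over the class axis
        injected = norm_w * (1 - lam) + self.queue_features.T * lam
        return l2_normalize(injected, axis=0)

    def cosine_logits(self, features, iters):
        """Cosine logits of a batch [batch, emb_dim] against the injected weights."""
        return l2_normalize(features, axis=1) @ self.injected_norm_weight(iters)

    def set_queue(self, features, labels, iters):
        """set_queue / VPLUpdateQueue: overwrite the entries of classes seen in this batch."""
        self.queue_features[labels] = l2_normalize(features, axis=1)
        self.queue_iters[labels] = iters


rng = np.random.default_rng(0)
vpl = VPLQueueSketch(rng.normal(size=(4, 3)), start_iters=5, allowed_delta=3)
vpl.set_queue(rng.normal(size=(2, 4)), np.array([1, 2]), iters=10)
print(vpl.queue_lambda(12))  # 0 for class 0, 0.15 for classes 1 and 2
print(vpl.queue_lambda(14))  # all 0: the entries are older than allowed_delta
```

Since `queue_iters` starts at 0 and `start_iters > allowed_delta`, a class that has never been refreshed never receives a non-zero lambda, which matches the `iters > start_iters` guard in the quoted code.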

Here is the result using start_iters=8000 / 4, allowed_delta=200 / 4 for batch_size=512:
Plot: [image not preserved]

Results

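| VPL                   | lfw      | cfp_fp   | agedb_30 | IJBB 1e-4 | IJBC 1e-4 |
| --------------------- | -------- | -------- | -------- | --------- | --------- |
| False                 | 0.997667 | 0.979429 | 0.978333 | 0.941188  | 0.955719  |
| start 8000, delta 200 | 0.997667 | 0.979571 | 0.978500 | 0.938559  | 0.955054  |
| start 2000, delta 50  | 0.998000 | 0.976429 | 0.977667 | 0.940117  | 0.956128  |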

IJBB / IJBC detail

| VPL, dataset (TAR @ FAR)    | 1e-06    | 1e-05    | 0.0001   | 0.001    | 0.01     | 0.1      | AUC      |
| --------------------------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |
| False, IJBB                 | 0.338948 | 0.875365 | 0.941188 | 0.960467 | 0.974684 | 0.983642 | 0.991774 |
| start 8000, delta 200, IJBB | 0.376241 | 0.8815   | 0.938559 | 0.962902 | 0.976339 | 0.985881 | 0.992184 |
| start 2000, delta 50, IJBB  | 0.353944 | 0.874002 | 0.940117 | 0.961538 | 0.974684 | 0.983934 | 0.991567 |
| False, IJBC                 | 0.848954 | 0.927954 | 0.955719 | 0.972184 | 0.982462 | 0.989109 | 0.994352 |
| start 8000, delta 200, IJBC | 0.877895 | 0.928568 | 0.955054 | 0.973513 | 0.983689 | 0.990387 | 0.994527 |
| start 2000, delta 50, IJBC  | 0.867004 | 0.926778 | 0.956128 | 0.972797 | 0.982257 | 0.989211 | 0.994179 |

Scaling by batch size is now the default adjustment: self.start_iters, self.allowed_delta = 8000 * 128 // batch_size, 200 * 128 // batch_size. I think it does make some sense, especially for IJBB / IJBC accuracy at FAR 1e-6. Also note that start 8000, delta 200 is actually higher at every TAR@FAR except 1e-4... It may be worth a try in some situations.
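The batch-size adjustment keeps the number of training samples seen (iterations × batch size) the same as in the official batch_size=128 config: 8000 × 128 = 1,024,000 samples, which is 2000 iterations at batch_size=512. A small hypothetical helper, `scale_vpl_iters`, applying the same integer formula:

```python
def scale_vpl_iters(batch_size, ref_start_iters=8000, ref_allowed_delta=200, ref_batch_size=128):
    """Rescale VPL iteration thresholds tuned for ref_batch_size to another batch size.

    Same formula as in the text: keep iterations * batch_size (samples seen) roughly constant.
    Integer division, so batch sizes that do not divide evenly round down.
    """
    start_iters = ref_start_iters * ref_batch_size // batch_size
    allowed_delta = ref_allowed_delta * ref_batch_size // batch_size
    return start_iters, allowed_delta


for bs in (128, 256, 512):
    print(bs, scale_vpl_iters(bs))
# 128 (8000, 200)
# 256 (4000, 100)
# 512 (2000, 50)
```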

There are now 2 parameters, vpl_start_iters and vpl_allowed_delta; use_vpl has been removed. VPL mode is now enabled by setting vpl_start_iters > 0, e.g. tt = train.Train(..., vpl_start_iters=8000). The defaults are vpl_start_iters=-1, vpl_allowed_delta=200.

Thank you for your work. It is indeed worth a try. An additional question about the IJB validation dataset: did you try their 1:N test?

I'm using my IJB_evals.py. I just ran a bunch of 1:N tests, and VPL start 8000, delta 200 performs reasonably well in this test:

  • VPL False
    >>>> Gallery 1 top1: 0.972289, top5: 0.979864, top10: 0.980662
    >>>> Gallery 2 top1: 0.933955, top5: 0.958698, top10: 0.966882
    >>>> Mean [Mean] top1: 0.952678, top5: 0.969036, top10: 0.973612
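    | far    | g1_tpir  | g1_thresh | g2_tpir  | g2_thresh | mean_tpir |
    | ------ | -------- | --------- | -------- | --------- | --------- |
    | 0.0001 | 0.188596 | 0.849763  | 0.032737 | 0.923226  | 0.110667  |
    | 0.001  | 0.361443 | 0.796886  | 0.169775 | 0.843299  | 0.265609  |
    | 0.01   | 0.919458 | 0.406492  | 0.881043 | 0.37894   | 0.90025   |
    | 0.1    | 0.962919 | 0.266365  | 0.924629 | 0.264748  | 0.943774  |
    | 1      | 0.972289 | 0.12289   | 0.933955 | 0.118918  | 0.953122  |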
  • VPL start 8000, delta 200
    >>>> Gallery 1 top1: 0.972289, top5: 0.981459, top10: 0.982855
    >>>> Gallery 2 top1: 0.937571, top5: 0.962314, top10: 0.967834
    >>>> Mean [Mean] top1: 0.954528, top5: 0.971665, top10: 0.975170
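    | far    | g1_tpir  | g1_thresh | g2_tpir   | g2_thresh | mean_tpir |
    | ------ | -------- | --------- | --------- | --------- | --------- |
    | 0.0001 | 0.226077 | 0.840018  | 0.0411115 | 0.917713  | 0.133594  |
    | 0.001  | 0.394338 | 0.788333  | 0.252569  | 0.812294  | 0.323454  |
    | 0.01   | 0.916069 | 0.407797  | 0.894366  | 0.359157  | 0.905217  |
    | 0.1    | 0.960526 | 0.270085  | 0.926532  | 0.269275  | 0.943529  |
    | 1      | 0.972289 | 0.123328  | 0.937571  | 0.123652  | 0.95493   |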
  • VPL start 2000, delta 50
    >>>> Gallery 1 top1: 0.973086, top5: 0.978868, top10: 0.981659
    >>>> Gallery 2 top1: 0.934336, top5: 0.960030, top10: 0.966121
    >>>> Mean [Mean] top1: 0.953262, top5: 0.969231, top10: 0.973710
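    | far    | g1_tpir  | g1_thresh | g2_tpir   | g2_thresh | mean_tpir |
    | ------ | -------- | --------- | --------- | --------- | --------- |
    | 0.0001 | 0.225279 | 0.839388  | 0.0371146 | 0.921194  | 0.131197  |
    | 0.001  | 0.399522 | 0.78508   | 0.186334  | 0.838523  | 0.292928  |
    | 0.01   | 0.916467 | 0.408205  | 0.863533  | 0.408381  | 0.89      |
    | 0.1    | 0.961324 | 0.263613  | 0.923297  | 0.264048  | 0.94231   |
    | 1      | 0.973086 | 0.11342   | 0.934336  | 0.114038  | 0.953711  |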
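For reference, the top1 / top5 / top10 numbers above are 1:N identification rates: the fraction of probes whose true identity is among the k most similar gallery entries. Below is a simplified closed-set sketch using cosine similarity. It is not what IJB_evals.py does (the IJB 1:N protocol evaluates two galleries, G1 / G2 above, and also reports open-set TPIR at a given FAR), and the function name `topk_identification` is hypothetical.

```python
import numpy as np


def topk_identification(probe_emb, probe_ids, gallery_emb, gallery_ids, ks=(1, 5, 10)):
    """Closed-set 1:N rank-k identification rate using cosine similarity.

    Assumes every probe identity is enrolled in the gallery (closed set).
    """
    probe = probe_emb / np.linalg.norm(probe_emb, axis=1, keepdims=True)
    gallery = gallery_emb / np.linalg.norm(gallery_emb, axis=1, keepdims=True)
    sims = probe @ gallery.T  # [num_probes, num_gallery]
    order = np.argsort(-sims, axis=1)  # most similar gallery entry first
    hits = gallery_ids[order] == probe_ids[:, None]  # True where the ranked entry is the right identity
    return {k: float(hits[:, :k].any(axis=1).mean()) for k in ks}


gallery_emb = np.eye(3)
gallery_ids = np.array([10, 20, 30])
probe_emb = np.array([[0.9, 0.1, 0.0], [0.1, 0.0, 0.9], [0.6, 0.5, 0.0]])
probe_ids = np.array([10, 30, 20])  # the last probe is closer to identity 10 than to 20
print(topk_identification(probe_emb, probe_ids, gallery_emb, gallery_ids, ks=(1, 2)))
# {1: 0.6666666666666666, 2: 1.0}
```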