VPL - Variational Prototype Learning for deep face recognition
abdikaiym01 opened this issue
Do you have a plan to implement their (insightface) latest work? It seems it does not work as they claimed in their paper. deepinsight/insightface#1801
Is it the current SOTA for the face recognition task?
I think an implementation is possible; I will try it and check whether it is better.
VPL mode is added. It can be enabled by tt = train.Train(..., use_vpl=True). I think it should be the same as the official implementation. Here are my test results using a basic EfficientNetV2B0 + AdamW:
import os
import losses, train, models
import tensorflow_addons as tfa
from tensorflow import keras
keras.mixed_precision.set_global_policy("mixed_float16")
data_basic_path = '/datasets/ms1m-retinaface-t1'
data_path = data_basic_path + '_112x112_folders'
eval_paths = [os.path.join(data_basic_path, ii) for ii in ['lfw.bin', 'cfp_fp.bin', 'agedb_30.bin']]
from keras_cv_attention_models import efficientnet
basic_model = efficientnet.EfficientNetV2B0(input_shape=(112, 112, 3), num_classes=0)
basic_model = models.buildin_models(basic_model, dropout=0, emb_shape=512, output_layer='GDC', bn_epsilon=1e-4, bn_momentum=0.9, scale=True, use_bias=False)
tt = train.Train(data_path, eval_paths=eval_paths,
save_path='TT_efv2_b0_swish_GDC_arc_emb512_dr0_adamw_5e4_bs512_ms1mv3_randaug_cos16_batch_float16_vpl.h5',
basic_model=basic_model, model=None, lr_base=0.01, lr_decay=0.5, lr_decay_steps=16, lr_min=1e-6, lr_warmup_steps=3,
batch_size=512, random_status=100, eval_freq=4000, output_weight_decay=1, use_vpl=True)
optimizer = tfa.optimizers.AdamW(learning_rate=1e-2, weight_decay=5e-4, exclude_from_weight_decay=["/gamma", "/beta"])
sch = [
{"loss": losses.ArcfaceLoss(scale=16), "epoch": 4, "optimizer": optimizer},
{"loss": losses.ArcfaceLoss(scale=32), "epoch": 3},
{"loss": losses.ArcfaceLoss(scale=64), "epoch": 46},
]
tt.train(sch, 0)
exit()
VPL | lfw | cfp_fp | agedb_30 | IJBB 1e-4 | IJBC 1e-4 |
---|---|---|---|---|---|
False | 0.997667 | 0.979429 | 0.978333 | 0.941188 | 0.955719 |
True | 0.997667 | 0.979571 | 0.978500 | 0.938559 | 0.955054 |
In the VPL paper the results were quite different; reality turns out to be another story. What do you think about it? Maybe your implementation is a little different from theirs, though I'm not sure about that?
Ya, I have compared them several times. It seems the main parts are the following (a standalone sketch of how these pieces fit together follows the list):
- insightface/prepare_queue_lambda
Here it's models.py#L247
def prepare_queue_lambda(self, label, iters):
    self.queue_lambda[:] = 0.0
    if iters > self.cfg['start_iters']:
        allowed_delta = self.cfg['allowed_delta']
        if self.vpl_mode == 0:
            past_iters = iters - self.queue_iters
            idx = torch.where(past_iters <= allowed_delta)[0]
            self.queue_lambda[idx] = self.cfg['lambda']
queue_lambda = tf.cond(
    self.iters > self.start_iters,
    lambda: tf.where(self.iters - self.queue_iters <= self.allowed_delta, self.vpl_lambda, 0.0),  # prepare_queue_lambda
    lambda: self.zero_queue_lambda,
)
- insightface vpl_mode>=0
Here it's models.py#L254
...
_lambda = self.queue_lambda.view(self.num_local, 1)
injected_weight = norm_weight * (1.0 - _lambda) + self.queue * _lambda
injected_norm_weight = normalize(injected_weight)
...
norm_w = K.l2_normalize(self.w, axis=0)
injected_weight = norm_w * (1 - queue_lambda) + tf.transpose(self.queue_features) * queue_lambda
injected_norm_weight = K.l2_normalize(injected_weight, axis=0)
- insightface/set_queue
Here in myCallbacks/VPLUpdateQueue
def set_queue(self, total_features, total_label, index_positive, iters):
    local_label = total_label[index_positive]
    sel_features = normalize(total_features[index_positive, :])
    self.queue[local_label, :] = sel_features
    self.queue_iters[local_label] = iters
As it needs the true labels, this update has to be done outside of the model. I think the logic matches the official one, unless I'm missing something...
class VPLUpdateQueue(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()

    def on_batch_end(self, batch, logs=None):
        # True labels are backed up by the loss, as the output layer itself cannot see them.
        batch_labels_back_up = self.model.loss[0].batch_labels_back_up
        update_label_pos = tf.expand_dims(batch_labels_back_up, 1)
        vpl_norm_dense_layer = self.model.layers[-1]
        # Refresh the queued feature and the last-seen iteration for the classes in this batch.
        updated_queue = tf.tensor_scatter_nd_update(vpl_norm_dense_layer.queue_features, update_label_pos, vpl_norm_dense_layer.norm_features)
        vpl_norm_dense_layer.queue_features.assign(updated_queue)
        iters = tf.repeat(vpl_norm_dense_layer.iters, tf.shape(batch_labels_back_up)[0])
        updated_queue_iters = tf.tensor_scatter_nd_update(vpl_norm_dense_layer.queue_iters, update_label_pos, iters)
        vpl_norm_dense_layer.queue_iters.assign(updated_queue_iters)
- In the official example_ms1mv3.py config, they are using 'start_iters': 8000, 'allowed_delta': 200 for batch_size = 128. I'm using that for batch_size=512, and may try start_iters=8000 / 4, allowed_delta=200 / 4 for batch_size=512 later.
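To make the comparison above easier to follow, here is a minimal standalone sketch of how those three pieces interact, using plain TensorFlow with made-up sizes. The loose variables and functions here are illustrative only, not the actual layer attributes in models.py:

import tensorflow as tf

# Illustrative sizes and hyper-parameters, not the actual training config.
num_classes, emb_shape, batch_size = 10, 8, 4
start_iters, allowed_delta, vpl_lambda = 5.0, 2.0, 0.15

# State kept for VPL: class prototypes (weights), a queue holding the latest
# sample feature per class, and the iteration at which each class was last seen.
w = tf.Variable(tf.random.normal([emb_shape, num_classes]))
queue_features = tf.Variable(tf.zeros([num_classes, emb_shape]))
queue_iters = tf.Variable(tf.zeros([num_classes]))

def injected_prototypes(iters):
    # prepare_queue_lambda: after start_iters, only classes refreshed within
    # allowed_delta iterations get blended with their queued feature.
    queue_lambda = tf.cond(
        iters > start_iters,
        lambda: tf.where(iters - queue_iters <= allowed_delta, vpl_lambda, 0.0),
        lambda: tf.zeros([num_classes]),
    )
    # Weight injection: mix normalized class weights with queued sample features.
    norm_w = tf.nn.l2_normalize(w, axis=0)
    injected = norm_w * (1 - queue_lambda) + tf.transpose(queue_features) * queue_lambda
    return tf.nn.l2_normalize(injected, axis=0)

def set_queue(features, labels, iters):
    # After each batch, store the normalized features and the current iteration
    # for the classes that appeared in this batch.
    pos = tf.expand_dims(labels, 1)
    queue_features.assign(tf.tensor_scatter_nd_update(queue_features, pos, tf.nn.l2_normalize(features, axis=1)))
    queue_iters.assign(tf.tensor_scatter_nd_update(queue_iters, pos, tf.repeat(iters, tf.shape(labels)[0])))

# One fake step: update the queue, then build the injected prototypes.
iters = tf.constant(7.0)
set_queue(tf.random.normal([batch_size, emb_shape]), tf.constant([1, 3, 5, 7]), iters)
print(injected_prototypes(iters).shape)  # (emb_shape, num_classes) == (8, 10)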
Here is the result using start_iters=8000 / 4, allowed_delta=200 / 4 for batch_size=512:
Results
VPL | lfw | cfp_fp | agedb_30 | IJBB 1e-4 | IJBC 1e-4 |
---|---|---|---|---|---|
False | 0.997667 | 0.979429 | 0.978333 | 0.941188 | 0.955719 |
start 8000, delta 200 | 0.997667 | 0.979571 | 0.978500 | 0.938559 | 0.955054 |
start 2000, delta 50 | 0.998000 | 0.976429 | 0.977667 | 0.940117 | 0.956128 |
IJBB / IJBC detail
VPL | 1e-06 | 1e-05 | 0.0001 | 0.001 | 0.01 | 0.1 | AUC |
---|---|---|---|---|---|---|---|
False, IJBB | 0.338948 | 0.875365 | 0.941188 | 0.960467 | 0.974684 | 0.983642 | 0.991774 |
start 8000, delta 200, IJBB | 0.376241 | 0.8815 | 0.938559 | 0.962902 | 0.976339 | 0.985881 | 0.992184 |
start 2000, delta 50, IJBB | 0.353944 | 0.874002 | 0.940117 | 0.961538 | 0.974684 | 0.983934 | 0.991567 |
False, IJBC | 0.848954 | 0.927954 | 0.955719 | 0.972184 | 0.982462 | 0.989109 | 0.994352 |
start 8000, delta 200, IJBC | 0.877895 | 0.928568 | 0.955054 | 0.973513 | 0.983689 | 0.990387 | 0.994527 |
start 2000, delta 50, IJBC | 0.867004 | 0.926778 | 0.956128 | 0.972797 | 0.982257 | 0.989211 | 0.994179 |
This is the default adjustment now: self.start_iters, self.allowed_delta = 8000 * 128 // batch_size, 200 * 128 // batch_size. I think it does make some sense, especially for IJBB / IJBC 1e-6 accuracy. Also notice start 8000, delta 200 is actually higher at all TAR@FAR points except 1e-4... It may be worth a try in some situations.
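For reference, a quick check of what that default scaling gives for a few batch sizes (this just restates the arithmetic above):

# Scale the official batch_size=128 settings so the VPL warm-up and the
# refresh window cover roughly the same number of samples at any batch size.
for batch_size in (128, 256, 512):
    start_iters, allowed_delta = 8000 * 128 // batch_size, 200 * 128 // batch_size
    print(batch_size, start_iters, allowed_delta)
# 128 -> 8000, 200
# 256 -> 4000, 100
# 512 -> 2000, 50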
It's 2 parameters now, vpl_start_iters and vpl_allowed_delta; use_vpl is abandoned. VPL mode is enabled by setting vpl_start_iters > 0 now, like tt = train.Train(..., vpl_start_iters=8000). Default vpl_start_iters=-1, vpl_allowed_delta=200.
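For completeness, a minimal usage sketch with the new arguments, reusing the same data_path / eval_paths / basic_model setup as the script above; the save_path name here is just a placeholder:

# vpl_start_iters > 0 turns VPL on; vpl_allowed_delta falls back to 200 if omitted.
tt = train.Train(data_path, save_path='TT_efv2_b0_vpl_test.h5', eval_paths=eval_paths,
    basic_model=basic_model, model=None, lr_base=0.01, lr_decay=0.5, lr_decay_steps=16, lr_min=1e-6, lr_warmup_steps=3,
    batch_size=512, random_status=100, eval_freq=4000, output_weight_decay=1,
    vpl_start_iters=8000, vpl_allowed_delta=200)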
Thank you for your work. It's indeed worth a try. An additional question about the IJB validation dataset: did you try their 1:N test?
I'm using my IJB_evals.py. Just ran a bunch of 1:N tests; VPL start 8000, delta 200 performs not badly in this test:
- VPL False
>>>> Gallery 1 top1: 0.972289, top5: 0.979864, top10: 0.980662
>>>> Gallery 2 top1: 0.933955, top5: 0.958698, top10: 0.966882
>>>> Mean [Mean] top1: 0.952678, top5: 0.969036, top10: 0.973612
far | g1_tpir | g1_thresh | g2_tpir | g2_thresh | mean_tpir |
---|---|---|---|---|---|
0.0001 | 0.188596 | 0.849763 | 0.032737 | 0.923226 | 0.110667 |
0.001 | 0.361443 | 0.796886 | 0.169775 | 0.843299 | 0.265609 |
0.01 | 0.919458 | 0.406492 | 0.881043 | 0.37894 | 0.90025 |
0.1 | 0.962919 | 0.266365 | 0.924629 | 0.264748 | 0.943774 |
1 | 0.972289 | 0.12289 | 0.933955 | 0.118918 | 0.953122 |
- VPL start 8000, delta 200
>>>> Gallery 1 top1: 0.972289, top5: 0.981459, top10: 0.982855
>>>> Gallery 2 top1: 0.937571, top5: 0.962314, top10: 0.967834
>>>> Mean [Mean] top1: 0.954528, top5: 0.971665, top10: 0.975170
far | g1_tpir | g1_thresh | g2_tpir | g2_thresh | mean_tpir |
---|---|---|---|---|---|
0.0001 | 0.226077 | 0.840018 | 0.0411115 | 0.917713 | 0.133594 |
0.001 | 0.394338 | 0.788333 | 0.252569 | 0.812294 | 0.323454 |
0.01 | 0.916069 | 0.407797 | 0.894366 | 0.359157 | 0.905217 |
0.1 | 0.960526 | 0.270085 | 0.926532 | 0.269275 | 0.943529 |
1 | 0.972289 | 0.123328 | 0.937571 | 0.123652 | 0.95493 |
- VPL start 2000, delta 50
>>>> Gallery 1 top1: 0.973086, top5: 0.978868, top10: 0.981659
>>>> Gallery 2 top1: 0.934336, top5: 0.960030, top10: 0.966121
>>>> Mean [Mean] top1: 0.953262, top5: 0.969231, top10: 0.973710
far | g1_tpir | g1_thresh | g2_tpir | g2_thresh | mean_tpir |
---|---|---|---|---|---|
0.0001 | 0.225279 | 0.839388 | 0.0371146 | 0.921194 | 0.131197 |
0.001 | 0.399522 | 0.78508 | 0.186334 | 0.838523 | 0.292928 |
0.01 | 0.916467 | 0.408205 | 0.863533 | 0.408381 | 0.89 |
0.1 | 0.961324 | 0.263613 | 0.923297 | 0.264048 | 0.94231 |
1 | 0.973086 | 0.11342 | 0.934336 | 0.114038 | 0.953711 |