tensorflow / similarity

TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.

Sync batch normalization across GPUs

yonigottesman opened this issue

When using distributed strategies (effectively a must for these models), the SimSiam, SimCLR, and Barlow Twins reference implementations all synchronize batch norm statistics across devices.
SimCLR uses tf.keras.layers.experimental.SyncBatchNormalization, while SimSiam and Barlow Twins use PyTorch's nn.SyncBatchNorm.convert_sync_batchnorm.
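For reference, a minimal sketch of synced batch norm in Keras under `tf.distribute.MirroredStrategy`. The toy model and the `sync_bn` helper are illustrative, not part of the library; the fallback to `BatchNormalization(synchronized=True)` is an assumption about newer Keras releases that no longer ship the experimental class.

```python
import tensorflow as tf

# Assumption: older tf.keras ships the experimental class named in the issue;
# newer Keras exposes the same behavior as BatchNormalization(synchronized=True).
_experimental = getattr(tf.keras.layers, "experimental", None)
_SyncBN = getattr(_experimental, "SyncBatchNormalization", None)


def sync_bn():
    """Return a batch norm layer whose statistics are synced across replicas."""
    if _SyncBN is not None:
        return _SyncBN()
    return tf.keras.layers.BatchNormalization(synchronized=True)


# MirroredStrategy runs one replica per visible GPU (a single replica on CPU).
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    inputs = tf.keras.Input(shape=(32, 32, 3))
    x = tf.keras.layers.Conv2D(16, 3, padding="same", use_bias=False)(inputs)
    # In training mode, the batch mean/variance are all-reduced over the
    # replicas, so they describe the global batch rather than one GPU's slice.
    x = sync_bn()(x)
    x = tf.keras.layers.ReLU()(x)
    outputs = tf.keras.layers.GlobalAveragePooling2D()(x)
    model = tf.keras.Model(inputs, outputs)
```

With plain BatchNormalization, each replica would normalize using only its local per-GPU batch, which is what the synced variants avoid.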

We should either rewrite the models with SyncBatchNormalization (which means implementing ResNet50 ourselves) or write a TF function, convert_sync_batchnorm, that replaces the BN layers.
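One possible shape for the second option, sketched under assumptions rather than taken from the library: a hypothetical `convert_sync_batchnorm(model)` built on `tf.keras.models.clone_model`, whose `clone_function` swaps each `BatchNormalization` for a synced layer with the same config. Because `clone_model` rebuilds the functional graph node by node, it should cope with skip connections, but it only applies to functional/Sequential models (such as the `tf.keras.applications` ones) and the clone gets freshly initialized weights. The version fallback in `_sync_bn_from_config` is also an assumption.

```python
import tensorflow as tf


def _sync_bn_from_config(config):
    """Build a synced BN layer from a BatchNormalization config.

    Assumption: older tf.keras has tf.keras.layers.experimental.SyncBatchNormalization;
    newer Keras offers BatchNormalization(synchronized=True) instead.
    """
    config = dict(config)
    config.pop("synchronized", None)
    experimental = getattr(tf.keras.layers, "experimental", None)
    sync_cls = getattr(experimental, "SyncBatchNormalization", None)
    if sync_cls is not None:
        return sync_cls.from_config(config)
    config["synchronized"] = True
    return tf.keras.layers.BatchNormalization.from_config(config)


def convert_sync_batchnorm(model):
    """Hypothetical TF counterpart of PyTorch's nn.SyncBatchNorm.convert_sync_batchnorm.

    Weights are NOT copied: the returned model is freshly initialized.
    """
    def clone_layer(layer):
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            return _sync_bn_from_config(layer.get_config())
        # Same as clone_model's default: a fresh layer with the same config.
        return layer.__class__.from_config(layer.get_config())

    return tf.keras.models.clone_model(model, clone_function=clone_layer)


# Usage: a tiny residual block standing in for a keras.applications model.
inputs = tf.keras.Input(shape=(8, 8, 4))
x = tf.keras.layers.Conv2D(4, 3, padding="same")(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Add()([x, inputs])  # skip connection
model = tf.keras.Model(inputs, x)

synced = convert_sync_batchnorm(model)
```

The same call on `tf.keras.applications.ResNet50(weights=None)` would replace every BN layer; with `weights="imagenet"` the pretrained values would be discarded, since cloning does not carry weights over.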

What do you think?
(I would like to work on this issue.)

That sounds like something we should definitely support. I think we can add it to the build_<model_name>() functions in the various architectures. We are already searching the tf.keras.applications models for the batch norm layers there, so we could add another option to find and replace them. Alternatively, we could add a standalone utility function that accepts a model and does the find and replace that way.

My only concern is that I'm not sure what impact this will have when we load the ImageNet weights. I've been freezing the batch norm layers in that case, so it may not matter here, since I'm assuming we train from scratch in the distributed case.
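A minimal sketch of freezing the batch norm layers of a pretrained backbone. `freeze_batch_norm` is a hypothetical helper, not necessarily how the repository does it; the backbone uses `weights=None` only so the example runs offline, where in practice it would be `weights="imagenet"`.

```python
import tensorflow as tf


def freeze_batch_norm(model):
    """Freeze every BatchNormalization layer in `model` (hypothetical helper).

    In tf.keras, trainable = False on a BatchNormalization layer also makes it
    run in inference mode: it normalizes with its stored moving mean/variance
    instead of the current batch statistics, and stops updating them.
    """
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            layer.trainable = False
    return model


# weights=None keeps the sketch offline; use weights="imagenet" for real runs.
backbone = tf.keras.applications.ResNet50(weights=None, include_top=False)
freeze_batch_norm(backbone)
```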

WDYT?

I'm not sure how easy it is to search and replace a layer in TF. It can be done, but it may be tricky.
Naively iterating over the layers like this breaks down with skip connections:

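```python
import tensorflow as tf
x = model.input
for layer in model.layers[1:]:  # skip the InputLayer
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        # fresh synced layer with the same config (weights are not copied)
        x = tf.keras.layers.experimental.SyncBatchNormalization.from_config(layer.get_config())(x)
    else:
        # breaks at merge layers such as Add, which also need the skip branch
        x = layer(x)
```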

I'll give it a try, but if it doesn't work I might need to add our own ResNet50 implementation, as we did for ResNet18, instead of taking it from keras.applications.

Regarding the ImageNet weights, this is indeed an issue. Loading all the weights and then resetting the BN layers is a bad idea. It should be explicit that you cannot use synced BN and load ImageNet weights.
Maybe it is possible to load the weights from regular BN into the synced version; that would be best, and I'll look into it too.
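A quick way to check whether regular BN weights can be loaded into the synced version, sketched with two hypothetical toy models that differ only in their BN layer. `make_sync_bn`, `copy_weights_by_name`, and `build` are illustrative names, and the version fallback is an assumption about newer Keras releases.

```python
import numpy as np
import tensorflow as tf


def make_sync_bn(**kwargs):
    # Assumption: older tf.keras ships experimental.SyncBatchNormalization;
    # newer Keras uses BatchNormalization(synchronized=True) instead.
    experimental = getattr(tf.keras.layers, "experimental", None)
    sync_cls = getattr(experimental, "SyncBatchNormalization", None)
    if sync_cls is not None:
        return sync_cls(**kwargs)
    return tf.keras.layers.BatchNormalization(synchronized=True, **kwargs)


def copy_weights_by_name(src, dst):
    """Copy weights layer by layer, matching layers by name (hypothetical helper).

    Both BN variants create gamma, beta, moving_mean and moving_variance in the
    same order, so get_weights()/set_weights() can move them across directly.
    """
    for layer in src.layers:
        weights = layer.get_weights()
        if weights:  # skip weightless layers such as the InputLayer
            dst.get_layer(layer.name).set_weights(weights)


def build(bn_factory):
    inputs = tf.keras.Input(shape=(8, 8, 3))
    x = tf.keras.layers.Conv2D(4, 3, padding="same", name="conv")(inputs)
    x = bn_factory(name="bn")(x)
    return tf.keras.Model(inputs, x)


regular = build(tf.keras.layers.BatchNormalization)
synced = build(make_sync_bn)

# Give the regular BN non-default values so the copy is observable.
rng = np.random.default_rng(0)
bn = regular.get_layer("bn")
bn.set_weights([rng.uniform(0.5, 1.5, size=w.shape).astype("float32") for w in bn.get_weights()])

copy_weights_by_name(regular, synced)
```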

Also, can you reassign me? I removed myself from the issue by mistake :-)

I added you back as the issue assignee.

I agree, the search and replace might be tricky. On the other hand, it would be good to avoid duplicating the tf.keras.applications models if we can. We didn't have a choice for ResNet18, but we should use the applications models where possible.

I also like your approach regarding the ImageNet weights. I think we can simply make synced BN usable only when training from scratch.
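The "from scratch only" rule could be enforced with a check like the following at the top of a build_<model_name>() function. The `check_batch_norm_options` name and its `weights`/`sync_batch_norm` parameters are hypothetical, chosen to mirror the `weights=` argument of `tf.keras.applications`.

```python
def check_batch_norm_options(weights, sync_batch_norm):
    """Reject synced BN combined with pretrained weights (hypothetical guard).

    Synced BN is only supported when training from scratch, so any pretrained
    weights (e.g. "imagenet") together with sync_batch_norm=True is an error.
    """
    if sync_batch_norm and weights is not None:
        raise ValueError(
            "sync_batch_norm=True is only supported when training from scratch; "
            f"got weights={weights!r}. Pass weights=None."
        )


# Usage: both of these are accepted.
check_batch_norm_options(weights=None, sync_batch_norm=True)
check_batch_norm_options(weights="imagenet", sync_batch_norm=False)
```

Failing early with an explicit message matches the point that users should be told they cannot combine synced BN with ImageNet weights, rather than having the BN layers silently reinitialized.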