zzdxjtu / nonlocal-tf

TensorFlow implementation of the Non-local Neural Network block in its various forms.

Home Page: https://arxiv.org/pdf/1711.07971.pdf


Non-local Neural Networks in TensorFlow

This is a TensorFlow (no Keras) implementation of the building blocks described in "Non-local Neural Networks" by Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He. The block can be dropped into any existing model and is compatible with TensorFlow's pre-trained ResNet models.
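For orientation, here is a minimal pure-Python sketch of the generic non-local operation as the paper defines it, y_i = (1/C(x)) Σ_j f(x_i, x_j) g(x_j), wrapped in the residual form z_i = W_z y_i + x_i. It is not the code from nonlocal_resnet_utils.py. It works on plain lists instead of tensors, replaces the 1×1 convolutions θ, φ, g and W_z with small matrices (`W_theta`, `W_phi`, `W_g`, `W_z` are hypothetical names), and covers only two of the paper's pairwise functions: embedded Gaussian (a softmax over j) and dot product (normalized by the number of positions N). It also illustrates why such a block can be inserted into a pre-trained network: with W_z set to zero the block returns its input unchanged. The paper gets the same effect by zero-initializing the scale of a batch-norm layer placed after W_z.

```python
import math

# Sketch only: features are lists of floats, one vector per position
# (flattened space/time). Matrices are lists of rows and stand in for 1x1 convs.

def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def nonlocal_block(xs, W_theta, W_phi, W_g, W_z, mode="embedded_gaussian"):
    """Return z_i = W_z y_i + x_i for every position i."""
    if mode not in ("embedded_gaussian", "dot_product"):
        raise ValueError("unsupported mode: " + mode)
    theta = [matvec(W_theta, x) for x in xs]
    phi = [matvec(W_phi, x) for x in xs]
    g = [matvec(W_g, x) for x in xs]
    n = len(xs)
    zs = []
    for i in range(n):
        scores = [dot(theta[i], phi[j]) for j in range(n)]
        if mode == "embedded_gaussian":
            # f = exp(theta_i . phi_j), C(x) = sum_j f  ->  softmax over j.
            m = max(scores)  # subtract the max for numerical stability
            exps = [math.exp(s - m) for s in scores]
            total = sum(exps)
            weights = [e / total for e in exps]
        else:
            # f = theta_i . phi_j, C(x) = N (number of positions).
            weights = [s / n for s in scores]
        # y_i: weighted sum of g(x_j) over all positions j.
        y = [sum(weights[j] * g[j][k] for j in range(n)) for k in range(len(g[0]))]
        # Residual connection: project y_i back with W_z and add the input.
        out = matvec(W_z, y)
        zs.append([a + b for a, b in zip(out, xs[i])])
    return zs
```

With an all-zero `W_z`, `nonlocal_block(xs, W_theta, W_phi, W_g, W_z)` returns `xs` for any input, so a pre-trained network behaves exactly as before until training moves W_z away from zero.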

Usage

The core code for the block is in nonlocal_resnet_utils.py; you can drop that file into your code and use it as-is. Usage is described in its docstring and should be straightforward.

An example ResNet50, dramatically simplified from TensorFlow's "official" implementation, is in nonlocal_resnet_v1_50_nl3.py; again, usage is described in its docstring.

The nice thing is that this simplified implementation is compatible with the ImageNet pre-trained weights released by Google. So, to create a ResNet50 with one non-local block at stage 3 and load pre-trained ImageNet weights, all you need is:

```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API
import nonlocal_resnet_v1_50_nl3 as model

# `images` is your batch of input images.
endpoints, body_prefix = model.endpoints(images, is_training=True)

# BEFORE DEFINING THE OPTIMIZER (because it creates new global variables):
model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)

# IF LOADING PRE-TRAINED WEIGHTS (inside your `with tf.Session() as sess:` block, in the first iteration):
saver = tf.train.Saver(model_variables)
saver.restore(sess, args.initial_checkpoint)  # path to the pre-trained ImageNet checkpoint

# Do something with `endpoints['model_output']`
```
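Why collect the variables before creating the optimizer? The sketch below mimics the name filtering with plain strings. In TF1, `tf.get_collection(key, scope)` keeps only the items whose `name` matches `scope` via `re.match`, and optimizers such as Adam typically add slot variables named after the variable they track (for example `.../weights/Adam` and `.../weights/Adam_1`). Those slots share the model's prefix but are not in the pre-trained checkpoint, so a `Saver` built from a list filtered after the optimizer exists would try to restore them and fail. The prefix `resnet_v1_50`, the variable names and the helper `filter_by_scope` are hypothetical stand-ins for whatever `body_prefix` and the model actually produce.

```python
import re

def filter_by_scope(names, scope):
    """Mimic tf.get_collection's scope filter: keep names matching `scope` via re.match."""
    return [name for name in names if re.match(scope, name)]

# Hypothetical model variables under a hypothetical prefix "resnet_v1_50".
model_vars = [
    "resnet_v1_50/conv1/weights:0",
    "resnet_v1_50/block1/unit_1/bottleneck_v1/conv1/weights:0",
]
before_optimizer = filter_by_scope(model_vars, "resnet_v1_50")

# Simulate Adam's two slot variables per model variable ("<name>/Adam", "<name>/Adam_1").
slot_vars = []
for name in model_vars:
    base = name.rsplit(":", 1)[0]  # drop the ":0" output index
    slot_vars += [base + "/Adam:0", base + "/Adam_1:0"]
after_optimizer = filter_by_scope(model_vars + slot_vars, "resnet_v1_50")

print(len(before_optimizer), len(after_optimizer))  # prints: 2 6
```

Filtering first keeps only the two weights the checkpoint actually contains; filtering afterwards would also pick up the four optimizer slots.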

Note that the last endpoint is the result of the global average pooling operation, so this code does not include the final 1000-class classification layer. If you are not doing ImageNet classification, you typically don't want that layer anyway (and if you are, why use this?).
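To make the note above concrete: global average pooling collapses an H × W × C feature map into one C-dimensional vector (2048 channels for ResNet50), and a classifier for your own task is then a linear layer on top of that vector. The pure-Python sketch below assumes a single feature map stored as nested lists; `global_average_pool` and `linear_head` are hypothetical helpers for illustration, not part of this repository.

```python
def global_average_pool(fmap):
    """fmap: H x W x C nested lists -> list of C per-channel means."""
    h, w, c = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [sum(fmap[i][j][k] for i in range(h) for j in range(w)) / (h * w)
            for k in range(c)]

def linear_head(features, weights, biases):
    """Hypothetical task-specific classifier: one logit per class (row of `weights`)."""
    return [sum(w * f for w, f in zip(row, features)) + b
            for row, b in zip(weights, biases)]

# A 2 x 2 feature map with 2 channels.
fmap = [[[1, 10], [3, 20]],
        [[5, 30], [7, 40]]]
pooled = global_average_pool(fmap)  # [4.0, 25.0]
logits = linear_head(pooled, [[1, 0], [0, 1], [1, 1]], [0, 0, 1])  # 3 classes
```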

About

License: MIT License

