google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home Page: https://flax.readthedocs.io


Extracting Backbone Intermediates for Use in an FPN

zjniu opened this issue · comments

Description of the model to be implemented

I am trying to build a Feature Pyramid Network (FPN) with an EfficientNetV2 or ResNet backbone. I have already written or found code for both backbones, implemented as nn.Module subclasses. I would like to take the intermediate feature maps produced by the bottom-up pathway (the backbone) and feed them into the top-down pathway for feature extraction. I see that the sow method is recommended as the way to extract intermediates, but retrieving sown values requires initializing and applying the nn.Module, which I believe cannot be done inside another module, in this case the FPN module I would like to build. How can I extract intermediates from one module within the __call__ of another module?

Specific points to consider

There are two options I am considering. One is to create an FPN wrapper module that takes the backbone as a parameter, runs the EfficientNetV2 or ResNet backbone, and extracts its intermediates. The other is to keep the backbone and FPN modules separate, with the FPN module taking the backbone's intermediates as input. The second option seems more plausible with the sow method, but I am not sure how training would work when the two modules are strung together in sequence.

Reference implementations in other frameworks

In this Keras FPN implementation (https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/models/fpn.py), Keras' get_layer method is used to obtain intermediate outputs. Is there something in Flax that can be used to accomplish the same task?