Extracting Backbone Intermediates for Use in an FPN
zjniu opened this issue · comments
Description of the model to be implemented
I am trying to build a Feature Pyramid Network (FPN) on top of an EfficientNetV2 or ResNet backbone. I already have code for both backbones, each implemented as an nn.Module. I would like to take the intermediate feature maps produced by the bottom-up pathway (the backbone) and feed them into the top-down pathway. I see that the sow method is recommended for extracting intermediates, but that requires initializing and applying the nn.Module. I believe this is not possible inside another module, in this case the FPN module that I would like to build. How can I extract intermediates from one module within the __call__ of another module?
Specific points to consider
There are two options that I am considering. One is to create an FPN wrapper module that takes a backbone parameter, runs the EfficientNetV2 or ResNet backbone, and extracts its intermediates. The other is to keep the backbone and FPN modules separate, with the FPN module taking the backbone's intermediates as input. The second option seems more plausible with the sow method, but I am not sure how training would work when the two modules are chained in sequence.
Reference implementations in other frameworks
In this Keras FPN implementation (https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/models/fpn.py), the get_layer method is used to obtain intermediate outputs. Is there an equivalent in Flax?