[Refactoring] even tighter integration with Lightning
fedebotu opened this issue
As suggested by @Junyoungpark, we could reduce complexity by letting PyTorch Lightning handle the RL algorithm directly, as sketched in the scheme below.
At the moment, we have the following levels:
└── RL4COLitModule <- PyTorch Lightning Module
    └── Model <- e.g. `AttentionModel` (= RL)
        └── Policy <- e.g. `AttentionModelPolicy`
This could be simplified to:
└── RLModel <- PyTorch Lightning Module (= RL, e.g. REINFORCE)
    └── Policy <- e.g. `AttentionModel`
This would also make it easier to implement algorithms such as PPO, since the inner optimization loop could run directly in PyTorch Lightning. Moreover, we would no longer need callbacks in RL4COLitModule that call into the models and their baselines, since everything would be integrated into a single module 🚀
Done with the new release 🚀