[Core] Design discussion: Refactor component/backend organisation
michaelschaarschmidt opened this issue
The current implementations grew out of an experimental design around multi-backend support. The `get_backend()` checks are undesirable and do not make for readable implementations.
This issue is meant to collect design improvements.
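To make the problem concrete, here is a hypothetical sketch of the current pattern (names and bodies are illustrative, not actual library code):

```python
def get_backend():
    # Stand-in for the library's real get_backend(); assume it returns "tf" or "pytorch".
    return "tf"

class PrioritizedReplay:
    def _graph_fn_insert(self, records):
        # Every graph_fn currently branches on the backend like this:
        if get_backend() == "tf":
            return self._tf_insert(records)       # TensorFlow op construction
        elif get_backend() == "pytorch":
            return self._pytorch_insert(records)  # PyTorch define-by-run logic

    def _tf_insert(self, records):
        ...  # backend-specific implementation omitted

    def _pytorch_insert(self, records):
        ...  # backend-specific implementation omitted
```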
Proposal 1: Components will be reorganised into a base component and backend-specific sub-classes.
Packages `tensorflow_components`/`pytorch_components` will mirror the folder structure of the base components and contain the backend-specific implementations.
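A minimal sketch of that layout (module contents are assumptions for illustration; only `TfMemory` is named in this issue):

```python
# components/memories/base_memory.py -- backend-agnostic interface
class Memory:
    def insert(self, records):
        raise NotImplementedError

# tensorflow_components/memories/tf_memory.py -- TF-specific subclass
class TfMemory(Memory):
    def insert(self, records):
        ...  # TensorFlow implementation, no get_backend() checks needed

# pytorch_components/memories/pytorch_memory.py -- PyTorch-specific subclass
class PyTorchMemory(Memory):
    def insert(self, records):
        ...  # PyTorch implementation
```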
Advantages: Avoids backend checks in implementations and clearly separates backends from interfaces.
Disadvantages: Multiplies the number of components, and the mirrored folder structure is potentially irritating:

```
components/
    memories/base_memory
tf_components/
    memories/tf_memory
pytorch_components/
    memories/pytorch_memory
```

versus keeping everything in one `memories` folder (which would, however, make imports more difficult for the package):

```
components/
    memories/
        base_memory
        tf_memory
        pytorch_memory
```
A possible solution avoiding all extra backend-prefixed classes (such as `TfMemory`) would be to use dynamic imports. The folder structure would then be similar to the first one suggested above, e.g.:
```
components/
    memories/
        memory.py              <- some abstract base class, no graph_fns
        prioritized_replay.py  <- dynamically imports graph_fns from `tf(or pytorch)/prioritized_replay.py`
        tf/
            prioritized_replay.py      <- contains only tf graph_fns, no classes
        pytorch/
            prioritized_replay.py      <- contains only pytorch graph_fns, no classes
```
We would then only need one extra dynamic-import statement per backend-agnostic Component-class file (e.g. `prioritized_replay.py` in the main `memories` directory):
```python
import importlib

# This is a dynamic import (dependent on the backend) of a graph-fn-containing python module.
_BACKEND_MOD = importlib.import_module(
    "rlcore.components.memories." + get_backend() + ".prioritized_replay"
)
```
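With `get_backend()` returning `"tf"`, for example, `_BACKEND_MOD` would resolve to the module `rlcore.components.memories.tf.prioritized_replay`.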
All graph_fns inside the main class file would reduce to mere backend-independent stubs:
```python
@graph_fn
def _graph_fn_insert(self, records):
    return _BACKEND_MOD.insert_function(records)
```
The actual backend-dependent implementation of `insert_function` would be in the `tf` or `pytorch` folders.
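For illustration, the TensorFlow module could look roughly like this (the exact signature of `insert_function` is an assumption; in particular, how the memory's state and variables reach these free functions is exactly what the first counterpoint below is about):

```python
# components/memories/tf/prioritized_replay.py
# Contains only TensorFlow graph_fns as free functions, no classes.
import tensorflow as tf

def insert_function(records):
    # Build and return the TF ops that insert `records` into the memory.
    ...
```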
This way, no import statements in any file would have to be conditional anymore, and we would not need any backend-specific sub-classing, which could potentially mess up the class hierarchy.
Counterpoints:
- The above would not solve variable creation, which is backend-dependent, non-static, and relies on inheritance; `prioritized_replay.py` would still need backend checks in this case, unless all get-variable methods were significantly refactored (see the sketch after this list).
- Readability (personal opinion): One would have to jump between graph_fns and their stubs instead of reading all relevant code in a single class, which would be my expectation. Clicking through a call chain would lead to a stub, and one would have to click again to reach the actual graph_fn.
- Performance: Would double the decorator call chain in eager/define-by-run mode because of the extra indirection.
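A rough sketch of the variable-creation issue from the first counterpoint (this `get_variable` signature is hypothetical; the library's real get-variable methods will differ):

```python
def get_variable(self, name, shape, dtype="float32"):
    # Variable creation cannot become a free backend function: it is
    # stateful and relies on the class hierarchy, so the check survives.
    if get_backend() == "tf":
        import tensorflow as tf
        return tf.get_variable(name, shape=shape, dtype=dtype)
    elif get_backend() == "pytorch":
        import torch
        return torch.nn.Parameter(torch.zeros(*shape))
```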