TODO: Modify wrapper design to functional
hyunwoongko opened this issue · comments
Kevin Ko commented
Describe a TODO feature
When multiple parallelization wrappers are stacked on one model, the wrapper-style design leads to several undesirable results.
Design notes
1. The old design
class TensorParallel:
    def __init__(self, model, ...):
        self.module = model
        self.xxx_for_tp = xxx

class PipelineParallel:
    def __init__(self, model, ...):
        self.module = model
        self.yyy_for_pp = yyy

model = XXXModel.from_pretrained(...)
model = TensorParallel(model)
model = PipelineParallel(model)
2. Problems
- 2.1. Accessibility
    model.module.module.module.xxx_for_tp  <--- attributes end up buried under nested .module levels
    model.generate  <--- unavailable
    model.save_pretrained  <--- unavailable
- 2.2. Checkpoint keys
"transformer.0.attn.q_proj.weight" => "module.module.module.transformer.0.attn.q_proj.weight"
3. New design: a class-like function
def TensorParallel(model, parallel_context, ...):
    # do something
    return model
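A minimal sketch of how the functional style could avoid both problems, under stated assumptions: `ToyModel` is a hypothetical stand-in, the `xxx` and `yyy` parameters and the function bodies are invented for illustration (the issue only says "do something"), and a function-style `PipelineParallel` is assumed by analogy with `TensorParallel`.

```python
class ToyModel:
    """Hypothetical stand-in for a pretrained model."""

    def __init__(self):
        self.weights = {"transformer.0.attn.q_proj.weight": [0.0]}

    def state_dict(self):
        return dict(self.weights)

    def generate(self):
        return "generated"


def TensorParallel(model, parallel_context, xxx=None):
    # Assumed body: attach the state to the model itself instead of wrapping it.
    model.parallel_context = parallel_context
    model.xxx_for_tp = xxx
    return model


def PipelineParallel(model, parallel_context, yyy=None):
    # Assumed by analogy; the issue only shows TensorParallel.
    model.parallel_context = parallel_context
    model.yyy_for_pp = yyy
    return model


original = ToyModel()
context = {"tensor_parallel_size": 2}  # hypothetical placeholder for a context object

model = TensorParallel(original, context, xxx="tp-state")
model = PipelineParallel(model, context, yyy="pp-state")

same_object = model is original          # no wrapper layer was introduced
output = model.generate()                # original methods stay reachable
tp_state = model.xxx_for_tp              # no `.module` chain needed
checkpoint_keys = list(model.state_dict())
print(same_object, output, tp_state, checkpoint_keys)
```

Because each function returns the object it received, the calls still read like the old `model = TensorParallel(model)` lines, while the model keeps its own type, methods, and parameter names.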