Fine-Tuning Point Transformer PT v3
RauchLukas opened this issue
I have a question about fine-tuning a pre-trained model on my own dataset.
From the Swin3D example, I figured out that it is possible to load a pretrained model without the classifier
layer. This way I can train on my own dataset with X classes in a transfer-learning approach.
This works for me as expected with the Swin3D
backbone by configuring the hooks in the config file:
...
hooks = [
    dict(type="CheckpointLoader", unload_keywords="backbone.classifier"),
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="InformationWriter"),
    dict(type="SemSegEvaluator"),
    dict(type="CheckpointSaver", save_freq=None),
    dict(type="PreciseEvaluator", test_last=False),
]
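For context, the key-filtering behind `unload_keywords` can be sketched roughly like this (illustrative names and a toy model, not Pointcept's actual implementation): keys whose names contain an unload keyword are dropped from the checkpoint before a non-strict load, so the corresponding layers keep their fresh initialization.

```python
import torch.nn as nn


def filtered_load(model: nn.Module, state_dict: dict, unload_keywords=()):
    # Drop every checkpoint key that contains one of the keywords.
    kept = {
        k: v
        for k, v in state_dict.items()
        if not any(kw in k for kw in unload_keywords)
    }
    # strict=False tolerates the keys we dropped; they show up as missing_keys.
    return model.load_state_dict(kept, strict=False)


# Toy example: a 25-class checkpoint loaded into a 3-class model,
# with the classifier (layer "1") excluded from loading.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))
ckpt = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 25)).state_dict()
info = filtered_load(model, ckpt, unload_keywords=("1.",))
print(info.missing_keys)  # -> ['1.weight', '1.bias']
```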
Nothing new until here, but:
if I try to unload the classification layer of a pretrained PT-v3 model,
I get a shape-mismatch error:
-- Process 4 terminated with the following error:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/workspace/Pointcept/pointcept/engines/launch.py", line 137, in _distributed_worker
main_func(*cfg)
File "/workspace/Pointcept/tools/train.py", line 51, in main_worker
trainer.train()
File "/workspace/Pointcept/pointcept/engines/train.py", line 153, in train
self.before_train()
File "/workspace/Pointcept/pointcept/engines/train.py", line 85, in before_train
h.before_train()
File "/workspace/Pointcept/pointcept/engines/hooks/misc.py", line 259, in before_train
load_state_info = self.trainer.model.load_state_dict(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DistributedDataParallel:
size mismatch for module.seg_head.weight: copying a param with shape torch.Size([25, 64]) from checkpoint, the shape in current model is torch.Size([11, 64]).
size mismatch for module.seg_head.bias: copying a param with shape torch.Size([25]) from checkpoint, the shape in current model is torch.Size([11]).
size mismatch for module.backbone.embedding.stem.conv.weight: copying a param with shape torch.Size([32, 5, 5, 5, 3]) from checkpoint, the shape in current model is torch.Size([32, 5, 5, 5, 6]).
I pretrained with the Structured3D dataset (25 classes) and tried to fine-tune with my dataset (11 classes).
The mismatched parameters are module.seg_head.weight, module.seg_head.bias, and module.backbone.embedding.stem.conv.weight.
This leads to my assumption that, in the case of PT-v3, the number of classes plays a role not only in the classifier.
Is there an option to fine-tune a PT-v3 backbone in the same way as in the Swin3D backbone example?
Very grateful for your help!
Cheers, Lukas
Hi Lukas, the most direct solution is to override the following config:
hooks = [
    dict(type="CheckpointLoader", keywords="module.seg_head.", replacement="module.seg_head_duplicate."),
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="InformationWriter"),
    dict(type="SemSegEvaluator"),
    dict(type="CheckpointSaver", save_freq=None),
    dict(type="PreciseEvaluator", test_last=False),
]
Renaming the old seg_head keys to a dummy name prevents their weights from being loaded.
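The idea of the rename trick can be sketched as follows (stand-in keys and values, not Pointcept's actual code): rewriting "module.seg_head." to a name that does not exist in the model means a non-strict load simply reports those entries as unexpected keys instead of raising a size-mismatch error.

```python
# A checkpoint state dict with stand-in values.
state_dict = {
    "module.seg_head.weight": "w25x64",
    "module.seg_head.bias": "b25",
    "module.backbone.layer.weight": "w",
}

# Rename the seg_head keys so they match nothing in the new model.
renamed = {
    k.replace("module.seg_head.", "module.seg_head_duplicate."): v
    for k, v in state_dict.items()
}
print(sorted(renamed))
# -> ['module.backbone.layer.weight',
#     'module.seg_head_duplicate.bias',
#     'module.seg_head_duplicate.weight']
```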
A more elegant way is to modify our CheckpointLoader in https://github.com/Pointcept/Pointcept-Dev/blob/v1.5.2_dev/pointcept/engines/hooks/misc.py#L207 and make the non-strict load even more lenient: if a weight's shape doesn't match the model's, skip loading that weight.
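One way such a shape-tolerant load could look (a sketch under assumed names, not the actual CheckpointLoader change): compare each checkpoint tensor's shape against the model's state dict and drop the mismatches before a non-strict load.

```python
import torch.nn as nn


def load_matching(model: nn.Module, state_dict: dict):
    """Load only the checkpoint entries whose shape matches the model."""
    own = model.state_dict()
    compatible = {
        k: v
        for k, v in state_dict.items()
        if k in own and v.shape == own[k].shape
    }
    skipped = sorted(set(state_dict) - set(compatible))
    model.load_state_dict(compatible, strict=False)
    return skipped


# Toy example mirroring the issue: a 25-class checkpoint loaded into an
# 11-class model; only the final (classifier) layer is skipped.
ckpt_model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 25))
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 11))
skipped = load_matching(model, ckpt_model.state_dict())
print(skipped)  # -> ['1.bias', '1.weight']
```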