Pointcept: a codebase for point cloud perception research. Latest works: PTv3 (CVPR'24 Oral), PPT (CVPR'24), OA-CNNs (CVPR'24), MSC (CVPR'23)

Fine-Tuning Point Transformer PT v3

RauchLukas opened this issue

I have a question about the possibility of fine-tuning a pre-trained model with my own dataset.

From the Swin3D example, I figured out that it is possible to load a pretrained model without the classifier layer. This way I can train on my own dataset with X classes in a transfer-learning approach.

This works as expected for me with the Swin3D backbone by configuring the hooks in the config file:

... 

hooks = [
    dict(type="CheckpointLoader", unload_keywords="backbone.classifier"),
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="InformationWriter"),
    dict(type="SemSegEvaluator"),
    dict(type="CheckpointSaver", save_freq=None),
    dict(type="PreciseEvaluator", test_last=False),
]

Nothing new until here, but:

If I try to unload the classification layer of a pretrained PT-v3 model, I get a shape-mismatch error:

-- Process 4 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/workspace/Pointcept/pointcept/engines/launch.py", line 137, in _distributed_worker
    main_func(*cfg)
  File "/workspace/Pointcept/tools/train.py", line 51, in main_worker
    trainer.train()
  File "/workspace/Pointcept/pointcept/engines/train.py", line 153, in train
    self.before_train()
  File "/workspace/Pointcept/pointcept/engines/train.py", line 85, in before_train
    h.before_train()
  File "/workspace/Pointcept/pointcept/engines/hooks/misc.py", line 259, in before_train
    load_state_info = self.trainer.model.load_state_dict(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DistributedDataParallel:
        size mismatch for module.seg_head.weight: copying a param with shape torch.Size([25, 64]) from checkpoint, the shape in current model is torch.Size([11, 64]).
        size mismatch for module.seg_head.bias: copying a param with shape torch.Size([25]) from checkpoint, the shape in current model is torch.Size([11]).
        size mismatch for module.backbone.embedding.stem.conv.weight: copying a param with shape torch.Size([32, 5, 5, 5, 3]) from checkpoint, the shape in current model is torch.Size([32, 5, 5, 5, 6]).

I pretrained on the Structured3D dataset (25 classes) and tried to fine-tune on my own dataset (11 classes).
The error is a shape mismatch in module.seg_head.weight, module.seg_head.bias, and module.backbone.embedding.stem.conv.weight.

This leads me to assume that, in the case of PT-v3, the number of classes plays a role not only in the classifier.

Is there a way to fine-tune a PT-v3 backbone in the same way as in the Swin3D backbone example?

I would be very grateful for your help!

cheers, Lukas

Hi Lukas, the most direct solution is to override the config as follows:

hooks = [
    dict(type="CheckpointLoader", keywords="module.seg_head.", replacement="module.seg_head_duplicate."),
    dict(type="IterationTimer", warmup_iter=2),
    dict(type="InformationWriter"),
    dict(type="SemSegEvaluator"),
    dict(type="CheckpointSaver", save_freq=None),
    dict(type="PreciseEvaluator", test_last=False),
]

This replaces the old seg_head keys with a dummy name so their pretrained weights are not loaded.
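For illustration, here is a minimal sketch of what that renaming boils down to (simplified and not the actual hook code; checkpoint and model are assumed to already be in scope): the matching checkpoint keys are rewritten before the non-strict load_state_dict call, so they no longer correspond to any parameter of the current model and are simply skipped.

# Simplified sketch of the renaming trick (illustrative, not the real CheckpointLoader):
# rename the pretrained seg_head keys so the non-strict load ignores them.
renamed = {}
for key, value in checkpoint["state_dict"].items():
    if "module.seg_head." in key:
        key = key.replace("module.seg_head.", "module.seg_head_duplicate.")
    renamed[key] = value

# With strict=False, checkpoint keys that do not exist in the model are reported
# as unexpected keys and ignored instead of raising an error.
model.load_state_dict(renamed, strict=False)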


A more elegant way is to modify our CheckpointLoader in https://github.com/Pointcept/Pointcept-Dev/blob/v1.5.2_dev/pointcept/engines/hooks/misc.py#L207 and make the non-strict loading even more lenient: if a weight's shape doesn't match the current model, skip loading that weight.
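A minimal sketch of that idea, assuming it is applied right before the load_state_dict call inside the hook (variable names are illustrative, not the actual hook code): compare each checkpoint entry against the current model's state dict and drop entries whose shapes differ, then load non-strictly.

# Sketch only: skip checkpoint weights whose shape does not match the current
# model, then fall back to the usual non-strict load.
model_state = model.state_dict()
filtered = {
    key: value
    for key, value in checkpoint["state_dict"].items()
    if key not in model_state or value.shape == model_state[key].shape
}
load_state_info = model.load_state_dict(filtered, strict=False)

With such a filter, the mismatched seg_head and stem weights would simply stay randomly initialized in the new model, while everything else is loaded from the checkpoint.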