dvlab-research / SphereFormer

The official implementation for "Spherical Transformer for LiDAR-based 3D Recognition" (CVPR 2023).

Size mismatch when loading your provided pre-trained model

Qizhi697 opened this issue

RuntimeError: Error(s) in loading state_dict for DistributedDataParallel:
        size mismatch for module.unet.conv.2.weight: copying a param with shape torch.Size([2, 2, 2, 64, 32]) from checkpoint, the shape in current model is torch.Size([64, 2, 2, 2, 32]).
        size mismatch for module.unet.u.conv.2.weight: copying a param with shape torch.Size([2, 2, 2, 128, 64]) from checkpoint, the shape in current model is torch.Size([128, 2, 2, 2, 64]).
        size mismatch for module.unet.u.u.conv.2.weight: copying a param with shape torch.Size([2, 2, 2, 256, 128]) from checkpoint, the shape in current model is torch.Size([256, 2, 2, 2, 128]).
        size mismatch for module.unet.u.u.u.conv.2.weight: copying a param with shape torch.Size([2, 2, 2, 256, 256]) from checkpoint, the shape in current model is torch.Size([256, 2, 2, 2, 256]).
        size mismatch for module.unet.u.u.u.deconv.2.weight: copying a param with shape torch.Size([2, 2, 2, 256, 256]) from checkpoint, the shape in current model is torch.Size([256, 2, 2, 2, 256]).
        size mismatch for module.unet.u.u.deconv.2.weight: copying a param with shape torch.Size([2, 2, 2, 128, 256]) from checkpoint, the shape in current model is torch.Size([128, 2, 2, 2, 256]).
        size mismatch for module.unet.u.deconv.2.weight: copying a param with shape torch.Size([2, 2, 2, 64, 128]) from checkpoint, the shape in current model is torch.Size([64, 2, 2, 2, 128]).
        size mismatch for module.unet.deconv.2.weight: copying a param with shape torch.Size([2, 2, 2, 32, 64]) from checkpoint, the shape in current model is torch.Size([32, 2, 2, 2, 64]).

Training runs successfully, and I can load my own trained model for evaluation.

Here is my semantic_kitti_unet32_spherical_transformer.yaml:

DATA:
  data_name: semantic_kitti
  data_root: /data/dataset/SemanticKITTI/dataset/
  label_mapping: util/semantic-kitti.yaml
  classes: 19
  fea_dim: 6
  voxel_size: [0.05, 0.05, 0.05]
  voxel_max: 120000 

TRAIN:
  # arch
  arch: unet_spherical_transformer
  input_c: 4
  m: 32
  block_reps: 2
  block_residual: True
  layers: [32, 64, 128, 256, 256]
  quant_size_scale: 24
  patch_size: 1 
  window_size: 6
  use_xyz: True
  sync_bn: True  # adopt sync_bn or not
  rel_query: True
  rel_key: True
  rel_value: True
  drop_path_rate: 0.3
  max_batch_points: 1000000
  class_weight: [ 3.1557,  8.7029,  7.8281,  6.1354,  6.3161,  7.9937,  8.9704,
                          10.1922,  1.6155,  4.2187,  1.9385,  5.5455,  2.0198,  2.6261,  1.3212,
                          5.1102,  2.5492,  5.8585,  7.3929]
  xyz_norm: False
  pc_range: [[-51.2, -51.2, -4], [51.2, 51.2, 2.4]]
  window_size_sphere: [2, 2, 80]
  window_size_scale: [2.0, 1.5]
  sphere_layers: [1,2,3,4,5]
  grad_checkpoint_layers: []
  a: 0.0125
  loss_name: ce_loss
  use_tta: False
  vote_num: 4

  # training
  aug: True
  transformer_lr_scale: 0.1 
  scheduler_update: step 
  scheduler: Poly

  power: 0.9
  use_amp: True
  train_gpu: [0,1] 
  workers: 32  # data loader workers 
  batch_size: 4 # batch size for training
  batch_size_val: 4 # batch size for validation during training, memory and speed tradeoff
  base_lr: 0.006 
  epochs: 100 
  start_epoch: 0
  momentum: 0.9
  weight_decay: 0.02 
  drop_rate: 0.5

  ignore_label: 255
  manual_seed: 123
  print_freq: 10
  save_freq: 1
  save_path: runs/semantic_kitti_unet32_spherical_transformer
  weight: model/model_semantic_kitti.pth
  resume: 
  evaluate: True  # evaluate on validation set; needs extra GPU memory, so a small batch_size_val is recommended
  eval_freq: 1
  val: True
  
Distributed:
  dist_url: tcp://127.0.0.1:6789
  dist_backend: 'nccl'
  multiprocessing_distributed: True
  world_size: 1
  rank: 0
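The eight mismatched tensors in the error appear to be the kernel-size-2 down- and upsampling convolutions between consecutive UNet levels, so their channel counts follow directly from `layers: [32, 64, 128, 256, 256]` in the config above. The sketch below is a hypothetical helper, `sampling_conv_shapes`, whose key-naming scheme is read off the error message rather than taken from the SphereFormer code. For each of those layers it regenerates both the checkpoint shape `[k, k, k, out, in]` and the current model shape `[out, k, k, k, in]`, which makes it easy to see that the two differ only in where the output-channel axis sits.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int, int, int]


def sampling_conv_shapes(layers: List[int], k: int = 2) -> List[Tuple[str, Shape, Shape]]:
    """Hypothetical helper: expected weight shapes of the down/up-sampling convs.

    Returns (key, checkpoint_shape, model_shape) tuples, where the checkpoint
    stores [k, k, k, out, in] and the current model expects [out, k, k, k, in]
    (both layouts inferred from the error message, not from the source code).
    """
    rows = []
    # Downsampling convs go from level i to level i + 1; each deeper level
    # adds one "u." to the parameter key.
    for i in range(len(layers) - 1):
        c_in, c_out = layers[i], layers[i + 1]
        key = "module.unet." + "u." * i + "conv.2.weight"
        rows.append((key, (k, k, k, c_out, c_in), (c_out, k, k, k, c_in)))
    # Upsampling (deconv) layers go back from level i + 1 to level i,
    # listed deepest first, matching the order in the error message.
    for i in reversed(range(len(layers) - 1)):
        c_in, c_out = layers[i + 1], layers[i]
        key = "module.unet." + "u." * i + "deconv.2.weight"
        rows.append((key, (k, k, k, c_out, c_in), (c_out, k, k, k, c_in)))
    return rows


if __name__ == "__main__":
    for key, ckpt_shape, model_shape in sampling_conv_shapes([32, 64, 128, 256, 256]):
        print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")
```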

Hi,

This issue seems to be caused by a spconv version mismatch. Please switch to a different spconv version and try again.
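If switching spconv versions is not an option, one possible workaround is to permute the affected checkpoint tensors before loading. This is a minimal sketch that assumes the axis order is the only difference between how the checkpoint and the installed spconv store these weights. The shapes in the error suggest this, but the thread does not confirm it. `needs_permute`, `convert_state_dict` and `CKPT_TO_MODEL_PERM` are hypothetical names, not part of SphereFormer or spconv. A 5-D weight is only moved from `[k, k, k, out, in]` to `[out, k, k, k, in]` when that exactly reproduces the shape the model expects.

```python
from typing import Mapping, Sequence

# Assumed permutation: move axis 3 (out channels) to the front,
# [k0, k1, k2, out, in] -> [out, k0, k1, k2, in].
CKPT_TO_MODEL_PERM = (3, 0, 1, 2, 4)


def needs_permute(ckpt_shape: Sequence[int], model_shape: Sequence[int]) -> bool:
    """True if a 5-D checkpoint shape matches the model shape only after permuting."""
    if len(ckpt_shape) != 5 or tuple(ckpt_shape) == tuple(model_shape):
        return False
    permuted = tuple(ckpt_shape[i] for i in CKPT_TO_MODEL_PERM)
    return permuted == tuple(model_shape)


def convert_state_dict(ckpt_sd: Mapping, model_sd: Mapping) -> dict:
    """Return a copy of ckpt_sd with layout-mismatched 5-D weights permuted.

    Works with any tensor type exposing .shape, .permute(*dims) and
    .contiguous(), e.g. torch.Tensor. Keys absent from model_sd are copied
    unchanged so load_state_dict can still report them.
    """
    converted = {}
    for key, value in ckpt_sd.items():
        target = model_sd.get(key)
        if target is not None and needs_permute(value.shape, target.shape):
            value = value.permute(*CKPT_TO_MODEL_PERM).contiguous()
        converted[key] = value
    return converted
```

With PyTorch, the checkpoint would be loaded with `torch.load(path, map_location="cpu")`. Its state dict (which key holds it inside the checkpoint file is not shown in this issue) is then passed to `convert_state_dict` together with `model.state_dict()`, and the result is given to `model.load_state_dict`. Comparing against the model's own shapes, rather than permuting every 5-D tensor, leaves weights that already match untouched.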

Thanks for your reply.
I checked my spconv installation: the package I installed is spconv-cu118, but it does not have version 2.1.21.

Could you try installing spconv-cu114?
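To double-check which spconv wheel is actually installed (for example after replacing `spconv-cu118` with `spconv-cu114`), a small standard-library check can list every installed distribution whose name starts with `spconv`, together with its version. This is a sketch using `importlib.metadata` (Python 3.8+). The function name `installed_spconv` is hypothetical. It also shows whether more than one `spconv-*` wheel is present at the same time.

```python
from importlib.metadata import distributions
from typing import Dict


def installed_spconv() -> Dict[str, str]:
    """Map each installed distribution named spconv* (e.g. spconv-cu114) to its version."""
    found = {}
    for dist in distributions():
        # Fall back to "" if a broken install has no Name field.
        name = dist.metadata.get("Name") or ""
        if name.lower().startswith("spconv"):
            found[name] = dist.version
    return found


if __name__ == "__main__":
    pkgs = installed_spconv()
    if not pkgs:
        print("no spconv distribution installed")
    for name, version in sorted(pkgs.items()):
        print(f"{name}=={version}")
```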