Pointcept / Pointcept

Pointcept: a codebase for point cloud perception research. Latest works: PTv3 (CVPR'24 Oral), PPT (CVPR'24), OA-CNNs (CVPR'24), MSC (CVPR'23)

An error was encountered when training with a custom s3dis dataset: channel size mismatch

jdjiang312 opened this issue · comments

Hello. I need your help with a problem: when training on a custom S3DIS-style dataset whose points carry x, y, z, and intensity, the following error occurs. I have already changed in_channels to 4 in point_transformer_v3m1_base.py, but the error persists. I hope to get your reply; thank you in advance for your time.

Traceback (most recent call last):
File "tools/train.py", line 40, in <module>
main()
File "tools/train.py", line 29, in main
launch(
File "/home/jiang/桌面/Pointcept/pointcept/engines/launch.py", line 95, in launch
main_func(*cfg)
File "tools/train.py", line 21, in main_worker
trainer.train() # run training
File "/home/jiang/桌面/Pointcept/pointcept/engines/train.py", line 183, in train
self.run_step()
File "/home/jiang/桌面/Pointcept/pointcept/engines/train.py", line 197, in run_step
output_dict = self.model(input_dict)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jiang/桌面/Pointcept/pointcept/models/default.py", line 54, in forward
point = self.backbone(point)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jiang/桌面/Pointcept/pointcept/models/point_transformer_v3/point_transformer_v3m1_base.py", line 704, in forward
point = self.embedding(point)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jiang/桌面/Pointcept/pointcept/models/point_transformer_v3/point_transformer_v3m1_base.py", line 514, in forward
point = self.stem(point)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jiang/桌面/Pointcept/pointcept/models/modules.py", line 66, in forward
input.sparse_conv_feat = module(input.sparse_conv_feat)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/spconv/pytorch/conv.py", line 755, in forward
return self._conv_forward(self.training,
File "/home/jiang/miniconda3/envs/PTv3/lib/python3.8/site-packages/spconv/pytorch/conv.py", line 169, in _conv_forward
assert input.features.shape[
AssertionError: channel size mismatch

Here is part of the code after I modified in_channels:
class PointTransformerV3(PointModule):
def __init__(
self,
in_channels=4,  # x, y, z plus intensity only
order=("z", "z_trans"),
stride=(2, 2, 2, 2),
enc_depths=(2, 2, 2, 6, 2),
enc_channels=(32, 64, 128, 256, 512),
enc_num_head=(2, 4, 8, 16, 32),
enc_patch_size=(48, 48, 48, 48, 48),
dec_depths=(2, 2, 2, 2),
dec_channels=(64, 64, 128, 256),
dec_num_head=(4, 4, 8, 16),
dec_patch_size=(48, 48, 48, 48),
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.3,
pre_norm=True,
shuffle_orders=True,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
cls_mode=False,
pdnorm_bn=False,
pdnorm_ln=False,
pdnorm_decouple=True,
pdnorm_adaptive=False,
pdnorm_affine=True,
pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
):
super().__init__()
self.num_stages = len(enc_depths)
self.order = [order] if isinstance(order, str) else order
self.cls_mode = cls_mode
self.shuffle_orders = shuffle_orders

    assert self.num_stages == len(stride) + 1
    assert self.num_stages == len(enc_depths)
    assert self.num_stages == len(enc_channels)
    assert self.num_stages == len(enc_num_head)
    assert self.num_stages == len(enc_patch_size)
    assert self.cls_mode or self.num_stages == len(dec_depths) + 1
    assert self.cls_mode or self.num_stages == len(dec_channels) + 1
    assert self.cls_mode or self.num_stages == len(dec_num_head) + 1
    assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1

    # norm layers
    if pdnorm_bn:
        bn_layer = partial(
            PDNorm,
            norm_layer=partial(
                nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine
            ),
            conditions=pdnorm_conditions,
            decouple=pdnorm_decouple,
            adaptive=pdnorm_adaptive,
        )
    else:
        bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
    if pdnorm_ln:
        ln_layer = partial(
            PDNorm,
            norm_layer=partial(nn.LayerNorm, elementwise_affine=pdnorm_affine),
            conditions=pdnorm_conditions,
            decouple=pdnorm_decouple,
            adaptive=pdnorm_adaptive,
        )
    else:
        ln_layer = nn.LayerNorm
    # activation layers
    act_layer = nn.GELU

    self.embedding = Embedding(
        in_channels=in_channels,
        embed_channels=enc_channels[0],
        norm_layer=bn_layer,
        act_layer=act_layer,
    )

    # encoder
    enc_drop_path = [
        x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
    ]
    self.enc = PointSequential()
    for s in range(self.num_stages):
        enc_drop_path_ = enc_drop_path[
            sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
        ]
        enc = PointSequential()
        if s > 0:
            enc.add(
                SerializedPooling(
                    in_channels=enc_channels[s - 1],
                    out_channels=enc_channels[s],
                    stride=stride[s - 1],
                    norm_layer=bn_layer,
                    act_layer=act_layer,
                ),
                name="down",
            )
        for i in range(enc_depths[s]):
            enc.add(
                Block(
                    channels=enc_channels[s],
                    num_heads=enc_num_head[s],
                    patch_size=enc_patch_size[s],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                    drop_path=enc_drop_path_[i],
                    norm_layer=ln_layer,
                    act_layer=act_layer,
                    pre_norm=pre_norm,
                    order_index=i % len(self.order),
                    cpe_indice_key=f"stage{s}",
                    enable_rpe=enable_rpe,
                    enable_flash=enable_flash,
                    upcast_attention=upcast_attention,
                    upcast_softmax=upcast_softmax,
                ),
                name=f"block{i}",
            )
        if len(enc) != 0:
            self.enc.add(module=enc, name=f"enc{s}")

    # decoder
    if not self.cls_mode:
        dec_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
        ]
        self.dec = PointSequential()
        dec_channels = list(dec_channels) + [enc_channels[-1]]
        for s in reversed(range(self.num_stages - 1)):
            dec_drop_path_ = dec_drop_path[
                sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
            ]
            dec_drop_path_.reverse()
            dec = PointSequential()
            dec.add(
                SerializedUnpooling(
                    in_channels=dec_channels[s + 1],
                    skip_channels=enc_channels[s],
                    out_channels=dec_channels[s],
                    norm_layer=bn_layer,
                    act_layer=act_layer,
                ),
                name="up",
            )
            for i in range(dec_depths[s]):
                dec.add(
                    Block(
                        channels=dec_channels[s],
                        num_heads=dec_num_head[s],
                        patch_size=dec_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=dec_drop_path_[i],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                    ),
                    name=f"block{i}",
                )
            self.dec.add(module=dec, name=f"dec{s}")

def forward(self, data_dict):
    point = Point(data_dict)
    point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
    point.sparsify()

    point = self.embedding(point)
    point = self.enc(point)
    if not self.cls_mode:
        point = self.dec(point)
    # else:
    #     point.feat = torch_scatter.segment_csr(
    #         src=point.feat,
    #         indptr=nn.functional.pad(point.offset, (1, 0)),
    #         reduce="mean",
    #     )
    return point

Hi, could you provide your full config so we can diagnose this further? Also, may I confirm whether you concatenated intensity directly to "coord" in our codebase? (If so, please refer to our configs for outdoor datasets and use "strength" as the key for the intensity value.)
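As a rough way to see why the maintainer asks about the config: the assertion in the traceback fires when the feature tensor reaching the first sparse convolution has a different channel count than the in_channels the layer was built with. The sketch below is plain Python with no Pointcept imports. The helper names and the sample dictionary are hypothetical, and it assumes that features are formed by concatenating the listed keys column-wise.

```python
def concatenated_width(sample, feat_keys):
    """Channel count after concatenating the listed keys column-wise."""
    return sum(len(sample[key][0]) for key in feat_keys)


def check_in_channels(sample, feat_keys, in_channels):
    """Raise early if the feature width does not match the model's in_channels."""
    width = concatenated_width(sample, feat_keys)
    if width != in_channels:
        raise ValueError(
            f"feat has {width} channels but in_channels={in_channels}"
        )
    return width


# Two illustrative points: x, y, z, one intensity value, and an RGB triple.
sample = {
    "coord": [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]],
    "strength": [[0.5], [0.9]],
    "color": [[255, 0, 0], [0, 255, 0]],
}

# coord (3) + strength (1) matches in_channels=4.
ok_width = check_in_channels(sample, ("coord", "strength"), in_channels=4)
```

Note that a value passed explicitly from a config overrides the default in the class signature, so editing only the default in point_transformer_v3m1_base.py may have no effect if the config still sets in_channels itself.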

Hi, thank you for the timely reply! Here is my detailed config. I concatenated intensity to "color" in your codebase, so the "color" information is actually my intensity information.
"""
Point Transformer - V3 Mode1

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

from functools import partial
from addict import Dict
import math
import torch
import torch.nn as nn
import spconv.pytorch as spconv
import torch_scatter
from timm.models.layers import DropPath

try:
    import flash_attn
except ImportError:
    flash_attn = None

from pointcept.models.point_prompt_training import PDNorm
from pointcept.models.builder import MODELS
from pointcept.models.utils.misc import offset2bincount
from pointcept.models.utils.structure import Point
from pointcept.models.modules import PointModule, PointSequential

class RPE(torch.nn.Module):  # relative position encoding
    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        # Boundary of the relative position range, derived from patch_size.
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        # Number of entries per axis in the relative position table.
        self.rpe_num = 2 * self.pos_bnd + 1
        # Learnable table of shape (3 * rpe_num, num_heads); one column per attention head.
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        # Initialize the table from a truncated normal distribution.
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        idx = (
            coord.clamp(-self.pos_bnd, self.pos_bnd)  # clamp into bnd
            + self.pos_bnd  # relative position to positive index
            + torch.arange(3, device=coord.device) * self.rpe_num  # x, y, z stride
        )
        out = self.rpe_table.index_select(0, idx.reshape(-1))
        out = out.view(idx.shape + (-1,)).sum(3)
        out = out.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)
        return out

class SerializedAttention(PointModule):
def __init__(
self,
channels,
num_heads,
patch_size,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
order_index=0,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
):
super().__init__()
assert channels % num_heads == 0
self.channels = channels
self.num_heads = num_heads
self.scale = qk_scale or (channels // num_heads) ** -0.5
self.order_index = order_index
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.enable_rpe = enable_rpe
self.enable_flash = enable_flash
if enable_flash:
assert (
enable_rpe is False
), "Set enable_rpe to False when enable Flash Attention"
assert (
upcast_attention is False
), "Set upcast_attention to False when enable Flash Attention"
assert (
upcast_softmax is False
), "Set upcast_softmax to False when enable Flash Attention"
assert flash_attn is not None, "Make sure flash_attn is installed."
self.patch_size = patch_size
self.attn_drop = attn_drop
else:
# when disable flash attention, we still don't want to use mask
# consequently, patch size will auto set to the
# min number of patch_size_max and number of points
self.patch_size_max = patch_size
self.patch_size = 0
self.attn_drop = torch.nn.Dropout(attn_drop)

    self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
    self.proj = torch.nn.Linear(channels, channels)
    self.proj_drop = torch.nn.Dropout(proj_drop)
    self.softmax = torch.nn.Softmax(dim=-1)
    self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

@torch.no_grad()
def get_rel_pos(self, point, order):
    K = self.patch_size
    rel_pos_key = f"rel_pos_{self.order_index}"
    if rel_pos_key not in point.keys():
        grid_coord = point.grid_coord[order]
        grid_coord = grid_coord.reshape(-1, K, 3)
        point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
    return point[rel_pos_key]

@torch.no_grad()
def get_padding_and_inverse(self, point):
    pad_key = "pad"
    unpad_key = "unpad"
    cu_seqlens_key = "cu_seqlens_key"
    if (
        pad_key not in point.keys()
        or unpad_key not in point.keys()
        or cu_seqlens_key not in point.keys()
    ):
        offset = point.offset
        bincount = offset2bincount(offset)
        bincount_pad = (
            torch.div(
                bincount + self.patch_size - 1,
                self.patch_size,
                rounding_mode="trunc",
            )
            * self.patch_size
        )
        # only pad point when num of points larger than patch_size
        mask_pad = bincount > self.patch_size
        bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
        _offset = nn.functional.pad(offset, (1, 0))
        _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
        pad = torch.arange(_offset_pad[-1], device=offset.device)
        unpad = torch.arange(_offset[-1], device=offset.device)
        cu_seqlens = []
        for i in range(len(offset)):
            unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
            if bincount[i] != bincount_pad[i]:
                pad[
                    _offset_pad[i + 1]
                    - self.patch_size
                    + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                ] = pad[
                    _offset_pad[i + 1]
                    - 2 * self.patch_size
                    + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    - self.patch_size
                ]
            pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
            cu_seqlens.append(
                torch.arange(
                    _offset_pad[i],
                    _offset_pad[i + 1],
                    step=self.patch_size,
                    dtype=torch.int32,
                    device=offset.device,
                )
            )
        point[pad_key] = pad
        point[unpad_key] = unpad
        point[cu_seqlens_key] = nn.functional.pad(
            torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
        )
    return point[pad_key], point[unpad_key], point[cu_seqlens_key]

def forward(self, point):
    if not self.enable_flash:
        self.patch_size = min(
            offset2bincount(point.offset).min().tolist(), self.patch_size_max
        )

    H = self.num_heads
    K = self.patch_size
    C = self.channels

    pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

    order = point.serialized_order[self.order_index][pad]
    inverse = unpad[point.serialized_inverse[self.order_index]]

    # padding and reshape feat and batch for serialized point patch
    qkv = self.qkv(point.feat)[order]

    if not self.enable_flash:
        # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
        q, k, v = (
            qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
        )
        # attn
        if self.upcast_attention:
            q = q.float()
            k = k.float()
        attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
        if self.enable_rpe:
            attn = attn + self.rpe(self.get_rel_pos(point, order))
        if self.upcast_softmax:
            attn = attn.float()
        attn = self.softmax(attn)
        attn = self.attn_drop(attn).to(qkv.dtype)
        feat = (attn @ v).transpose(1, 2).reshape(-1, C)
    else:
        feat = flash_attn.flash_attn_varlen_qkvpacked_func(
            qkv.half().reshape(-1, 3, H, C // H),
            cu_seqlens,
            max_seqlen=self.patch_size,
            dropout_p=self.attn_drop if self.training else 0,
            softmax_scale=self.scale,
        ).reshape(-1, C)
        feat = feat.to(qkv.dtype)
    feat = feat[inverse]

    # ffn
    feat = self.proj(feat)
    feat = self.proj_drop(feat)
    point.feat = feat
    return point

class MLP(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Block(PointModule):
def __init__(
self,
channels,
num_heads,
patch_size=48,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.0,
norm_layer=nn.LayerNorm,
act_layer=nn.GELU,
pre_norm=True,
order_index=0,
cpe_indice_key=None,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
):
super().__init__()
self.channels = channels
self.pre_norm = pre_norm

    self.cpe = PointSequential(
        spconv.SubMConv3d(
            channels,
            channels,
            kernel_size=3,
            bias=True,
            indice_key=cpe_indice_key,
        ),
        nn.Linear(channels, channels),
        norm_layer(channels),
    )

    self.norm1 = PointSequential(norm_layer(channels))
    self.attn = SerializedAttention(
        channels=channels,
        patch_size=patch_size,
        num_heads=num_heads,
        qkv_bias=qkv_bias,
        qk_scale=qk_scale,
        attn_drop=attn_drop,
        proj_drop=proj_drop,
        order_index=order_index,
        enable_rpe=enable_rpe,
        enable_flash=enable_flash,
        upcast_attention=upcast_attention,
        upcast_softmax=upcast_softmax,
    )
    self.norm2 = PointSequential(norm_layer(channels))
    self.mlp = PointSequential(
        MLP(
            in_channels=channels,
            hidden_channels=int(channels * mlp_ratio),
            out_channels=channels,
            act_layer=act_layer,
            drop=proj_drop,
        )
    )
    self.drop_path = PointSequential(
        DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
    )

def forward(self, point: Point):
    shortcut = point.feat
    point = self.cpe(point)
    point.feat = shortcut + point.feat
    shortcut = point.feat
    if self.pre_norm:
        point = self.norm1(point)
    point = self.drop_path(self.attn(point))
    point.feat = shortcut + point.feat
    if not self.pre_norm:
        point = self.norm1(point)

    shortcut = point.feat
    if self.pre_norm:
        point = self.norm2(point)
    point = self.drop_path(self.mlp(point))
    point.feat = shortcut + point.feat
    if not self.pre_norm:
        point = self.norm2(point)
    point.sparse_conv_feat.replace_feature(point.feat)
    return point

class SerializedPooling(PointModule):
def __init__(
self,
in_channels,
out_channels,
stride=2,
norm_layer=None,
act_layer=None,
reduce="max",
shuffle_orders=True,
traceable=True, # record parent and cluster
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels

    assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
    # TODO: add support to grid pool (any stride)
    self.stride = stride
    assert reduce in ["sum", "mean", "min", "max"]
    self.reduce = reduce
    self.shuffle_orders = shuffle_orders
    self.traceable = traceable

    self.proj = nn.Linear(in_channels, out_channels)
    if norm_layer is not None:
        self.norm = PointSequential(norm_layer(out_channels))
    if act_layer is not None:
        self.act = PointSequential(act_layer())

def forward(self, point: Point):
    pooling_depth = (math.ceil(self.stride) - 1).bit_length()
    if pooling_depth > point.serialized_depth:
        pooling_depth = 0
    assert {
        "serialized_code",
        "serialized_order",
        "serialized_inverse",
        "serialized_depth",
    }.issubset(
        point.keys()
    ), "Run point.serialization() point cloud before SerializedPooling"

    code = point.serialized_code >> pooling_depth * 3
    code_, cluster, counts = torch.unique(
        code[0],
        sorted=True,
        return_inverse=True,
        return_counts=True,
    )
    # indices of point sorted by cluster, for torch_scatter.segment_csr
    _, indices = torch.sort(cluster)
    # index pointer for sorted point, for torch_scatter.segment_csr
    idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
    # head_indices of each cluster, for reduce attr e.g. code, batch
    head_indices = indices[idx_ptr[:-1]]
    # generate down code, order, inverse
    code = code[:, head_indices]
    order = torch.argsort(code)
    inverse = torch.zeros_like(order).scatter_(
        dim=1,
        index=order,
        src=torch.arange(0, code.shape[1], device=order.device).repeat(
            code.shape[0], 1
        ),
    )

    if self.shuffle_orders:
        perm = torch.randperm(code.shape[0])
        code = code[perm]
        order = order[perm]
        inverse = inverse[perm]

    # collect information
    point_dict = Dict(
        feat=torch_scatter.segment_csr(
            self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
        ),
        coord=torch_scatter.segment_csr(
            point.coord[indices], idx_ptr, reduce="mean"
        ),
        grid_coord=point.grid_coord[head_indices] >> pooling_depth,
        serialized_code=code,
        serialized_order=order,
        serialized_inverse=inverse,
        serialized_depth=point.serialized_depth - pooling_depth,
        batch=point.batch[head_indices],
    )

    if "condition" in point.keys():
        point_dict["condition"] = point.condition
    if "context" in point.keys():
        point_dict["context"] = point.context

    if self.traceable:
        point_dict["pooling_inverse"] = cluster
        point_dict["pooling_parent"] = point
    point = Point(point_dict)
    if self.norm is not None:
        point = self.norm(point)
    if self.act is not None:
        point = self.act(point)
    point.sparsify()
    return point

class SerializedUnpooling(PointModule):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
norm_layer=None,
act_layer=None,
traceable=False, # record parent and cluster
):
super().__init__()
self.proj = PointSequential(nn.Linear(in_channels, out_channels))
self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))

    if norm_layer is not None:
        self.proj.add(norm_layer(out_channels))
        self.proj_skip.add(norm_layer(out_channels))

    if act_layer is not None:
        self.proj.add(act_layer())
        self.proj_skip.add(act_layer())

    self.traceable = traceable

def forward(self, point):
    assert "pooling_parent" in point.keys()
    assert "pooling_inverse" in point.keys()
    parent = point.pop("pooling_parent")
    inverse = point.pop("pooling_inverse")
    point = self.proj(point)
    parent = self.proj_skip(parent)
    parent.feat = parent.feat + point.feat[inverse]

    if self.traceable:
        parent["unpooling_parent"] = point
    return parent

class Embedding(PointModule):
    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        # TODO: check remove spconv
        self.stem = PointSequential(
            conv=spconv.SubMConv3d(
                in_channels,
                embed_channels,
                kernel_size=5,
                padding=1,
                bias=False,
                indice_key="stem",
            )
        )
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")

    def forward(self, point: Point):
        point = self.stem(point)
        return point

@MODELS.register_module("PT-v3m1")
class PointTransformerV3(PointModule):
def __init__(
self,
in_channels=4,  # x, y, z plus intensity only
order=("z", "z_trans"),
stride=(2, 2, 2, 2),
enc_depths=(2, 2, 2, 6, 2),
enc_channels=(32, 64, 128, 256, 512),
enc_num_head=(2, 4, 8, 16, 32),
enc_patch_size=(48, 48, 48, 48, 48),
dec_depths=(2, 2, 2, 2),
dec_channels=(64, 64, 128, 256),
dec_num_head=(4, 4, 8, 16),
dec_patch_size=(48, 48, 48, 48),
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.3,
pre_norm=True,
shuffle_orders=True,
enable_rpe=False,
enable_flash=True,
upcast_attention=True,
upcast_softmax=True,
cls_mode=False,
pdnorm_bn=False,
pdnorm_ln=False,
pdnorm_decouple=True,
pdnorm_adaptive=False,
pdnorm_affine=True,
pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
):
super().__init__()
self.num_stages = len(enc_depths)
self.order = [order] if isinstance(order, str) else order
self.cls_mode = cls_mode
self.shuffle_orders = shuffle_orders

    assert self.num_stages == len(stride) + 1
    assert self.num_stages == len(enc_depths)
    assert self.num_stages == len(enc_channels)
    assert self.num_stages == len(enc_num_head)
    assert self.num_stages == len(enc_patch_size)
    assert self.cls_mode or self.num_stages == len(dec_depths) + 1
    assert self.cls_mode or self.num_stages == len(dec_channels) + 1
    assert self.cls_mode or self.num_stages == len(dec_num_head) + 1
    assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1

    # norm layers
    if pdnorm_bn:
        bn_layer = partial(
            PDNorm,
            norm_layer=partial(
                nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine
            ),
            conditions=pdnorm_conditions,
            decouple=pdnorm_decouple,
            adaptive=pdnorm_adaptive,
        )
    else:
        bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
    if pdnorm_ln:
        ln_layer = partial(
            PDNorm,
            norm_layer=partial(nn.LayerNorm, elementwise_affine=pdnorm_affine),
            conditions=pdnorm_conditions,
            decouple=pdnorm_decouple,
            adaptive=pdnorm_adaptive,
        )
    else:
        ln_layer = nn.LayerNorm
    # activation layers
    act_layer = nn.GELU

    self.embedding = Embedding(
        in_channels=in_channels,
        embed_channels=enc_channels[0],
        norm_layer=bn_layer,
        act_layer=act_layer,
    )

    # encoder
    enc_drop_path = [
        x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
    ]
    self.enc = PointSequential()
    for s in range(self.num_stages):
        enc_drop_path_ = enc_drop_path[
            sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
        ]
        enc = PointSequential()
        if s > 0:
            enc.add(
                SerializedPooling(
                    in_channels=enc_channels[s - 1],
                    out_channels=enc_channels[s],
                    stride=stride[s - 1],
                    norm_layer=bn_layer,
                    act_layer=act_layer,
                ),
                name="down",
            )
        for i in range(enc_depths[s]):
            enc.add(
                Block(
                    channels=enc_channels[s],
                    num_heads=enc_num_head[s],
                    patch_size=enc_patch_size[s],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=proj_drop,
                    drop_path=enc_drop_path_[i],
                    norm_layer=ln_layer,
                    act_layer=act_layer,
                    pre_norm=pre_norm,
                    order_index=i % len(self.order),
                    cpe_indice_key=f"stage{s}",
                    enable_rpe=enable_rpe,
                    enable_flash=enable_flash,
                    upcast_attention=upcast_attention,
                    upcast_softmax=upcast_softmax,
                ),
                name=f"block{i}",
            )
        if len(enc) != 0:
            self.enc.add(module=enc, name=f"enc{s}")

    # decoder
    if not self.cls_mode:
        dec_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
        ]
        self.dec = PointSequential()
        dec_channels = list(dec_channels) + [enc_channels[-1]]
        for s in reversed(range(self.num_stages - 1)):
            dec_drop_path_ = dec_drop_path[
                sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
            ]
            dec_drop_path_.reverse()
            dec = PointSequential()
            dec.add(
                SerializedUnpooling(
                    in_channels=dec_channels[s + 1],
                    skip_channels=enc_channels[s],
                    out_channels=dec_channels[s],
                    norm_layer=bn_layer,
                    act_layer=act_layer,
                ),
                name="up",
            )
            for i in range(dec_depths[s]):
                dec.add(
                    Block(
                        channels=dec_channels[s],
                        num_heads=dec_num_head[s],
                        patch_size=dec_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=dec_drop_path_[i],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                    ),
                    name=f"block{i}",
                )
            self.dec.add(module=dec, name=f"dec{s}")

def forward(self, data_dict):
    point = Point(data_dict)
    point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
    point.sparsify()

    point = self.embedding(point)
    point = self.enc(point)
    if not self.cls_mode:
        point = self.dec(point)
    # else:
    #     point.feat = torch_scatter.segment_csr(
    #         src=point.feat,
    #         indptr=nn.functional.pad(point.offset, (1, 0)),
    #         reduce="mean",
    #     )
    return point
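
To make the decoder wiring above easier to follow, here is a minimal sketch of the channel bookkeeping alone. It assumes that `num_stages` equals the number of encoder stages and uses the `enc_channels` and `dec_channels` values from the config posted later in this thread. `decoder_channel_plan` is a hypothetical helper written for illustration, not part of Pointcept.

```python
def decoder_channel_plan(enc_channels, dec_channels):
    """List (stage, in, skip, out) for each unpooling step, deepest stage first."""
    num_stages = len(enc_channels)  # assumption: one entry per encoder stage
    # Same trick as the constructor: append the deepest encoder width so that
    # dec_channels[s + 1] is defined for the first (deepest) decoder stage.
    dec_channels = list(dec_channels) + [enc_channels[-1]]
    plan = []
    for s in reversed(range(num_stages - 1)):
        plan.append((s, dec_channels[s + 1], enc_channels[s], dec_channels[s]))
    return plan


plan = decoder_channel_plan(
    enc_channels=(32, 64, 128, 256, 512),
    dec_channels=(64, 64, 128, 256),
)
for stage, in_ch, skip_ch, out_ch in plan:
    print(f"dec{stage}: in={in_ch} skip={skip_ch} out={out_ch}")
```

Each decoder stage takes the output of the stage below it as `in`, the matching encoder stage as `skip`, and emits `dec_channels[s]` channels. Note that none of these widths depend on `in_channels`, which only affects the embedding layer.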

Thank you so much for your reply again!

I concatenated the intensity into "color" in your codebase, which means the "color" information is actually my intensity information.

Hi, I think you can try assigning the name "strength" to the intensity information and refer to our nuScenes config when modifying your config (https://github.com/Pointcept/Pointcept/blob/main/configs/nuscenes/semseg-pt-v3m1-0-base.py#L120-L124).
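
As a rough illustration of why the key names matter, the sketch below assumes that the `Collect` transform builds `feat` by concatenating the arrays named in `feat_keys` along the channel axis, so that the backbone's `in_channels` has to equal the summed width of those keys. `collect_feat` is a hypothetical stand-in written for this example, not Pointcept's implementation, and the sample values are made up.

```python
def collect_feat(data_dict, feat_keys):
    """Concatenate per-point feature rows for the given keys (illustrative stand-in)."""
    columns = [data_dict[key] for key in feat_keys]
    num_points = len(columns[0])
    # Every key must describe the same number of points.
    assert all(len(col) == num_points for col in columns)
    # Join the rows of each key side by side: (N, c1) + (N, c2) -> (N, c1 + c2).
    return [sum((list(col[i]) for col in columns), []) for i in range(num_points)]


# Two made-up points with xyz coordinates and a single intensity channel.
sample = {
    "coord": [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
    "strength": [[0.7], [0.2]],
}

feat_xyz_i = collect_feat(sample, ("coord", "strength"))
feat_i_only = collect_feat(sample, ("strength",))

print(len(feat_xyz_i[0]))   # channel count when coord and strength are both collected
print(len(feat_i_only[0]))  # channel count when only strength is collected
```

Under that assumption, collecting `coord` and `strength` together gives a 4-channel feature, while collecting `strength` alone gives a 1-channel feature.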

Thank you so much for your patient reply! I will give it a try.

Hi. I have another question. You suggested assigning the name "strength" to the intensity information and modifying the config. Does that mean I need to preprocess the custom dataset with a version of preprocess_s3dis.py in which "color" is changed to "strength", like this?

class2label = {cls: i for i, cls in enumerate(classes)}
source_dir = os.path.join(dataset_root, room)
save_path = os.path.join(output_root, room) + ".pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
object_path_list = sorted(glob.glob(os.path.join(source_dir, "Annotations/*.txt")))

room_coords = []
room_strength = []
room_normals = []
room_semantic_gt = []
room_instance_gt = []

for object_id, object_path in enumerate(object_path_list):
    object_name = os.path.basename(object_path).split("_")[0]
    obj = np.loadtxt(object_path)
    coords = obj[:, :3]
    strength = obj[:, 3:4]
    # note: in some room there is 'stairs' class
    class_name = object_name if object_name in classes else "clutter"
    semantic_gt = np.repeat(class2label[class_name], coords.shape[0])
    semantic_gt = semantic_gt.reshape([-1, 1])
    instance_gt = np.repeat(object_id, coords.shape[0])
    instance_gt = instance_gt.reshape([-1, 1])

    room_coords.append(coords)
    room_strength.append(strength)
    room_semantic_gt.append(semantic_gt)
    room_instance_gt.append(instance_gt)

@Gofinge Hi, sorry to bother you, but I really can't find the problem after two days of debugging, and I really need your help:

I have assigned the point cloud intensity information to "strength" in the data preprocessing script as you suggested, and set in_channels to 4 in the config file. Strangely, the same error still occurs:
assert input.features.shape[
AssertionError: channel size mismatch

Here is my config file:
weight = None
resume = False
evaluate = True
test_only = False
seed = 2737042
save_path = 'exp/s3dis/s3dis-pt-v3m1-0-base-leakage'
num_worker = 4
batch_size = 2
batch_size_val = None
batch_size_test = None
epoch = 100
eval_epoch = 100
sync_bn = False
enable_amp = True
empty_cache = False
find_unused_parameters = False
mix_prob = 0.8
param_dicts = [dict(keyword='block', lr=0.0006)]
hooks = [
dict(type='CheckpointLoader'),
dict(type='IterationTimer', warmup_iter=2),
dict(type='InformationWriter'),
dict(type='SemSegEvaluator'),
dict(type='CheckpointSaver', save_freq=None),
dict(type='PreciseEvaluator', test_last=False)
]
train = dict(type='DefaultTrainer')
test = dict(type='SemSegTester', verbose=True)
model = dict(
type='DefaultSegmentorV2',
num_classes=2,
backbone_out_channels=64,
backbone=dict(
type='PT-v3m1',
in_channels=4,
order=['z', 'z-trans', 'hilbert', 'hilbert-trans'],
stride=(2, 2, 2, 2),
enc_depths=(2, 2, 2, 6, 2),
enc_channels=(32, 64, 128, 256, 512),
enc_num_head=(2, 4, 8, 16, 32),
enc_patch_size=(1024, 1024, 1024, 1024, 1024),
dec_depths=(2, 2, 2, 2),
dec_channels=(64, 64, 128, 256),
dec_num_head=(4, 4, 8, 16),
dec_patch_size=(1024, 1024, 1024, 1024),
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
drop_path=0.3,
shuffle_orders=True,
pre_norm=True,
enable_rpe=False,
enable_flash=True,
upcast_attention=False,
upcast_softmax=False,
cls_mode=False,
pdnorm_bn=False,
pdnorm_ln=False,
pdnorm_decouple=True,
pdnorm_adaptive=False,
pdnorm_affine=True,
pdnorm_conditions=('ScanNet', 'S3DIS', 'Structured3D')),
criteria=[
dict(type='CrossEntropyLoss', loss_weight=1.0, ignore_index=-1),
dict(
type='LovaszLoss',
mode='multiclass',
loss_weight=1.0,
ignore_index=-1)
])
optimizer = dict(type='AdamW', lr=0.006, weight_decay=0.05)
scheduler = dict(
type='OneCycleLR',
max_lr=[0.006, 0.0006],
pct_start=0.05,
anneal_strategy='cos',
div_factor=10.0,
final_div_factor=1000.0)
dataset_type = 'S3DISDataset'
data_root = '/media/jiang/data_tunnel/#DATA/S3DIS_leakage/S3DIS_leakage_processed_strength'
data = dict(
num_classes=2,
ignore_index=-1,
names=['leakage', 'background'],
train=dict(
type='S3DISDataset',
split=('Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'),
data_root=
'/media/jiang/data_tunnel/#DATA/S3DIS_leakage/S3DIS_leakage_processed_strength',
transform=[
dict(type='CenterShift', apply_z=True),
dict(
type='RandomDropout',
dropout_ratio=0.2,
dropout_application_ratio=0.2),
dict(
type='RandomRotate',
angle=[-1, 1],
axis='z',
center=[0, 0, 0],
p=0.5),
dict(
type='RandomRotate',
angle=[-0.015625, 0.015625],
axis='x',
p=0.5),
dict(
type='RandomRotate',
angle=[-0.015625, 0.015625],
axis='y',
p=0.5),
dict(type='RandomScale', scale=[0.9, 1.1]),
dict(type='RandomFlip', p=0.5),
dict(type='RandomJitter', sigma=0.005, clip=0.02),
dict(type='ChromaticAutoContrast', p=0.2, blend_factor=None),
dict(type='ChromaticTranslation', p=0.95, ratio=0.05),
dict(type='ChromaticJitter', p=0.95, std=0.05),
dict(
type='GridSample',
grid_size=0.02,
hash_type='fnv',
mode='train',
return_grid_coord=True),
dict(type='SphereCrop', sample_rate=0.6, mode='random'),
dict(type='SphereCrop', point_max=204800, mode='random'),
dict(type='CenterShift', apply_z=False),
dict(type='NormalizeColor'),
dict(type='ToTensor'),
dict(
type='Collect',
keys=('coord', 'grid_coord', 'segment'),
feat_keys=('strength', ))
],
test_mode=False,
loop=1),
val=dict(
type='S3DISDataset',
split='Area_5',
data_root=
'/media/jiang/data_tunnel/#DATA/S3DIS_leakage/S3DIS_leakage_processed_strength',
transform=[
dict(type='CenterShift', apply_z=True),
dict(
type='Copy',
keys_dict=dict(coord='origin_coord',
segment='origin_segment')),
dict(
type='GridSample',
grid_size=0.02,
hash_type='fnv',
mode='train',
return_grid_coord=True),
dict(type='CenterShift', apply_z=False),
dict(type='NormalizeColor'),
dict(type='ToTensor'),
dict(
type='Collect',
keys=('coord', 'grid_coord', 'origin_coord', 'segment',
'origin_segment'),
offset_keys_dict=dict(
offset='coord', origin_offset='origin_coord'),
feat_keys=('strength', ))
],
test_mode=False),
test=dict(
type='S3DISDataset',
split='Area_5',
data_root=
'/media/jiang/data_tunnel/#DATA/S3DIS_leakage/S3DIS_leakage_processed_strength',
transform=[
dict(type='CenterShift', apply_z=True),
dict(type='NormalizeColor')
],
test_mode=True,
test_cfg=dict(
voxelize=dict(
type='GridSample',
grid_size=0.02,
hash_type='fnv',
mode='test',
keys=('coord', 'strength'),
return_grid_coord=True),
crop=None,
post_transform=[
dict(type='CenterShift', apply_z=False),
dict(type='ToTensor'),
dict(
type='Collect',
keys=('coord', 'grid_coord', 'index'),
feat_keys=('strength', ))
],
aug_transform=[[{
'type': 'RandomScale',
'scale': [0.9, 0.9]
}], [{
'type': 'RandomScale',
'scale': [0.95, 0.95]
}], [{
'type': 'RandomScale',
'scale': [1, 1]
}], [{
'type': 'RandomScale',
'scale': [1.05, 1.05]
}], [{
'type': 'RandomScale',
'scale': [1.1, 1.1]
}],
[{
'type': 'RandomScale',
'scale': [0.9, 0.9]
}, {
'type': 'RandomFlip',
'p': 1
}],
[{
'type': 'RandomScale',
'scale': [0.95, 0.95]
}, {
'type': 'RandomFlip',
'p': 1
}],
[{
'type': 'RandomScale',
'scale': [1, 1]
}, {
'type': 'RandomFlip',
'p': 1
}],
[{
'type': 'RandomScale',
'scale': [1.05, 1.05]
}, {
'type': 'RandomFlip',
'p': 1
}],
[{
'type': 'RandomScale',
'scale': [1.1, 1.1]
}, {
'type': 'RandomFlip',
'p': 1
}]])))
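
One thing that may be worth checking in the config above: every `Collect` entry lists only `strength` in `feat_keys`, while the backbone is built with `in_channels=4`. The sketch below assumes that `feat` is the concatenation of the keys in `feat_keys` and nothing else. `KEY_WIDTHS` and `collected_channels` are hypothetical names introduced for illustration, and the widths are assumptions about the data rather than values read from Pointcept.

```python
# Assumed per-point widths of each raw key; adjust to match the real data.
KEY_WIDTHS = {"coord": 3, "color": 3, "normal": 3, "strength": 1}


def collected_channels(transforms, key_widths=KEY_WIDTHS):
    """Return the feature width produced by the last Collect entry, or None."""
    width = None
    for step in transforms:
        if step.get("type") == "Collect":
            width = sum(key_widths[k] for k in step.get("feat_keys", ()))
    return width


# Trimmed-down copy of the relevant parts of the config above.
in_channels = 4
train_transform = [
    dict(type="ToTensor"),
    dict(
        type="Collect",
        keys=("coord", "grid_coord", "segment"),
        feat_keys=("strength",),
    ),
]

width = collected_channels(train_transform)
if width != in_channels:
    print(f"mismatch: Collect yields {width} channel(s), backbone expects {in_channels}")
```

If the assumption holds, either `feat_keys=('coord', 'strength')` with `in_channels=4`, or `feat_keys=('strength', )` with `in_channels=1`, would make the two sides agree.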

@Gofinge Hi, I changed my strategy: I copied the intensity twice and stored the result under "color", so my three "RGB" channels are now intensity, intensity, intensity.

But when I set in_channels to 6, the "AssertionError: channel size mismatch" error still occurred. I tried to use torch.load to output the shape of my data set, and the result showed (17620072, 3) (17620072, 3). Obviously there is no problem with my dataset.

data = torch.load(sample_data_path)
print(data['coord'].shape)
print(data['color'].shape)

Finally I set in_channels to 3, and surprisingly, it worked, so I added some print statements to xxx.py to print the output of each layer, like this:

def forward(self, data_dict):
    print("Initial data keys:", data_dict.keys())  # print the keys of the data dict
    print("Initial coord example:", data_dict['coord'][:10])  # show the first 10 coordinates, assuming 'coord' is the coordinate key

    point = Point(data_dict)
    point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
    point.sparsify()

    print("Serialized and sparsified point features shape:", point.feat.shape)  # check the feature shape

    point = self.embedding(point)
    print("After embedding features shape:", point.feat.shape)  # feature shape after the embedding

    point = self.enc(point)
    if not self.cls_mode:
        point = self.dec(point)

    print("Output features shape:", point.feat.shape)  # check the output feature shape
    return point

Here are the results:

Initial data keys: dict_keys(['coord', 'grid_coord', 'segment', 'offset', 'feat', 'batch'])
Initial coord example: tensor([
        [ 1.0290e-03,  1.2180e+00,  2.2660e+00],
        [ 7.1236e-03,  1.2098e+00,  2.2660e+00],
        [ 6.5242e-03,  1.2126e+00,  2.2817e+00],
        [-1.7980e-02,  1.2176e+00,  2.2636e+00],
        [ 5.9669e-03,  1.2181e+00,  2.2455e+00],
        [ 1.4903e-03,  1.2191e+00,  2.2874e+00],
        [ 1.4696e-02,  1.2027e+00,  2.2745e+00],
        [ 2.3315e-02,  1.2162e+00,  2.2593e+00],
        [ 1.9005e-02,  1.2236e+00,  2.2807e+00],
        [ 2.5252e-02,  1.2189e+00,  2.2722e+00]], device='cuda:0')
Serialized and sparsified point features shape: torch.Size([377402, 3])
Before embedding: torch.Size([377402, 3])
After embedding: torch.Size([377402, 32])
After embedding features shape: torch.Size([377402, 32])
Block input feature shape: torch.Size([377402, 32])
Attention input features shape: torch.Size([377402, 32])
Attention output features shape: torch.Size([377402, 32])
Block output feature shape: torch.Size([377402, 32])
Block input feature shape: torch.Size([377402, 32])
Attention input features shape: torch.Size([377402, 32])
Attention output features shape: torch.Size([377402, 32])
Block output feature shape: torch.Size([377402, 32])
Block input feature shape: torch.Size([83636, 64])
Attention input features shape: torch.Size([83636, 64])
Attention output features shape: torch.Size([83636, 64])
Block output feature shape: torch.Size([83636, 64])
Block input feature shape: torch.Size([83636, 64])
Attention input features shape: torch.Size([83636, 64])
Attention output features shape: torch.Size([83636, 64])
Block output feature shape: torch.Size([83636, 64])
Block input feature shape: torch.Size([19583, 128])
Attention input features shape: torch.Size([19583, 128])
Attention output features shape: torch.Size([19583, 128])
Block output feature shape: torch.Size([19583, 128])
Block input feature shape: torch.Size([19583, 128])
Attention input features shape: torch.Size([19583, 128])
Attention output features shape: torch.Size([19583, 128])
Block output feature shape: torch.Size([19583, 128])
Block input feature shape: torch.Size([4554, 256])
Attention input features shape: torch.Size([4554, 256])
Attention output features shape: torch.Size([4554, 256])
Block output feature shape: torch.Size([4554, 256])
Block input feature shape: torch.Size([4554, 256])
Attention input features shape: torch.Size([4554, 256])
Attention output features shape: torch.Size([4554, 256])
Block output feature shape: torch.Size([4554, 256])
Block input feature shape: torch.Size([4554, 256])
Attention input features shape: torch.Size([4554, 256])
Attention output features shape: torch.Size([4554, 256])
Block output feature shape: torch.Size([4554, 256])
Block input feature shape: torch.Size([4554, 256])
Attention input features shape: torch.Size([4554, 256])
Attention output features shape: torch.Size([4554, 256])
Block output feature shape: torch.Size([4554, 256])
Block input feature shape: torch.Size([4554, 256])
Attention input features shape: torch.Size([4554, 256])
Attention output features shape: torch.Size([4554, 256])
Block output feature shape: torch.Size([4554, 256])
Block input feature shape: torch.Size([4554, 256])
Attention input features shape: torch.Size([4554, 256])
Attention output features shape: torch.Size([4554, 256])
Block output feature shape: torch.Size([4554, 256])
Block input feature shape: torch.Size([1076, 512])
Attention input features shape: torch.Size([1076, 512])
Attention output features shape: torch.Size([1076, 512])
Block output feature shape: torch.Size([1076, 512])
Block input feature shape: torch.Size([1076, 512])
Attention input features shape: torch.Size([1076, 512])
Attention output features shape: torch.Size([1076, 512])
Block output feature shape: torch.Size([1076, 512])
Block input feature shape: torch.Size([4554, 256])
Attention input features shape: torch.Size([4554, 256])
Attention output features shape: torch.Size([4554, 256])
Block output feature shape: torch.Size([4554, 256])
Block input feature shape: torch.Size([4554, 256])
Attention input features shape: torch.Size([4554, 256])
Attention output features shape: torch.Size([4554, 256])
Block output feature shape: torch.Size([4554, 256])
Block input feature shape: torch.Size([19583, 128])
Attention input features shape: torch.Size([19583, 128])
Attention output features shape: torch.Size([19583, 128])
Block output feature shape: torch.Size([19583, 128])
Block input feature shape: torch.Size([19583, 128])
Attention input features shape: torch.Size([19583, 128])
Attention output features shape: torch.Size([19583, 128])
Block output feature shape: torch.Size([19583, 128])
Block input feature shape: torch.Size([83636, 64])
Attention input features shape: torch.Size([83636, 64])
Attention output features shape: torch.Size([83636, 64])
Block output feature shape: torch.Size([83636, 64])
Block input feature shape: torch.Size([83636, 64])
Attention input features shape: torch.Size([83636, 64])
Attention output features shape: torch.Size([83636, 64])
Block output feature shape: torch.Size([83636, 64])
Block input feature shape: torch.Size([377402, 64])
Attention input features shape: torch.Size([377402, 64])
Attention output features shape: torch.Size([377402, 64])
Block output feature shape: torch.Size([377402, 64])
Block input feature shape: torch.Size([377402, 64])
Attention input features shape: torch.Size([377402, 64])
Attention output features shape: torch.Size([377402, 64])
Block output feature shape: torch.Size([377402, 64])
Output features shape: torch.Size([377402, 64])

So I have a question: does in_channels represent the dimension of the complete data (coordinates plus features) or only the dimension of the features? If it is the former, why does setting in_channels to 6 not work?
I still can't figure out the error, and it is very important to me because I plan to add other features as input later. I sincerely hope you can help me find the problem. Thanks in advance for your time and kind help!
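
The log above offers a hint: the input dict carries a separate `feat` entry next to `coord`, and its shape is `[377402, 3]` when training works with `in_channels=3`. That suggests `in_channels` counts the width of `feat` only. A minimal sketch of a guard that reports the mismatch more readably than the bare assertion is shown below; `check_in_channels` is a hypothetical helper, not a Pointcept function.

```python
def check_in_channels(feat_shape, in_channels):
    """Raise a descriptive error if the feature width differs from in_channels."""
    num_points, width = feat_shape
    if width != in_channels:
        raise ValueError(
            f"feat has {width} channel(s) for {num_points} points, "
            f"but the backbone was built with in_channels={in_channels}"
        )
    return True


# Shape taken from the log above: torch.Size([377402, 3]).
print(check_in_channels((377402, 3), in_channels=3))

try:
    check_in_channels((377402, 3), in_channels=6)
except ValueError as err:
    print(err)
```

In a real run the shape could be passed as `tuple(data_dict["feat"].shape)` at the top of `forward`, before the embedding is applied.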