houqb / CoordAttention

Code for our CVPR 2021 paper on coordinate attention


I want to ask: how should I modify the code to train on 3D images?

InvincibleXiao opened this issue · comments

Do you mean point cloud data?

I mean 3D medical images.

I think it is correct.

I am not sure why we concatenate at dim=2. If I concatenate the 3D feature maps at dim=2, will that work?
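For the 2D case, dim=2 works because both pooled tensors end in a trailing singleton dimension, leaving dim 2 as the only spatial axis to stack along. A minimal shape check, with hypothetical sizes (the pooling and permute follow the 2D pattern described in this thread):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 32, 14, 10)            # [batch, channel, height, width]

pool_h = nn.AdaptiveAvgPool2d((None, 1))  # average over width  -> [n, c, h, 1]
pool_w = nn.AdaptiveAvgPool2d((1, None))  # average over height -> [n, c, 1, w]

x_h = pool_h(x)                           # [2, 32, 14, 1]
x_w = pool_w(x).permute(0, 1, 3, 2)       # [2, 32, 10, 1] after moving w to dim 2

# Both tensors are now [n, c, length, 1], so dim=2 is where h and w can stack.
y = torch.cat([x_h, x_w], dim=2)          # [2, 32, 24, 1]
```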

It needs a try.

It needs a try.

I also want to use it on 3D images. I imitated the 2D code and rewrote it for 3D, but the main problem I ran into is `torch.cat()`.

Assume I have a 3D dataset where the input shape is [2,32,112,160,128], i.e. [batch, channel, height, width, depth].
And I use:
self.pool_h = nn.AdaptiveAvgPool3d((None, None,1))
self.pool_w = nn.AdaptiveAvgPool3d((None,1,None))
self.pool_d = nn.AdaptiveAvgPool3d((1,None,None))
......
x_h = self.pool_h(x)
x_w = self.pool_w(x).permute(0, 1, 2, 4, 3)
x_d = self.pool_d(x).permute(0, 1, 4, 3, 2)

Then I get:
x_h.shape == [2,32,112,160,1]
x_w.shape == [2,32,112,128,1]
x_d.shape == [2,32,128,160,1]

Now I can use torch.cat([x_h, x_w], dim=3) to get a tensor of shape [2,32,112,288,1], but x_d does not match x_h along dim 2 (128 vs. 112), so I don't know how to concatenate [x_h, x_w, x_d]. Can you give me some advice?
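One way around the mismatch is to pool each direction over *both* of the other spatial dimensions, so every descriptor collapses to [n, c, axis, 1, 1] and all three can be concatenated along dim 2, mirroring the h+w stacking of the 2D version. This is a sketch of one possible 3D extension, not the authors' code; the pooling specs and permutes below are assumptions, and the sizes are shrunk from the thread's [2,32,112,160,128] for brevity:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 32, 11, 16, 13)  # stand-in for [batch, channel, H, W, D]

# Pool each descriptor over BOTH remaining spatial dims, so all three end
# in a trailing [1, 1] (assumed 3D extension, not the paper's 2D code).
pool_h = nn.AdaptiveAvgPool3d((None, 1, 1))  # keep H -> [n, c, H, 1, 1]
pool_w = nn.AdaptiveAvgPool3d((1, None, 1))  # keep W -> [n, c, 1, W, 1]
pool_d = nn.AdaptiveAvgPool3d((1, 1, None))  # keep D -> [n, c, 1, 1, D]

x_h = pool_h(x)                              # [2, 32, 11, 1, 1]
x_w = pool_w(x).permute(0, 1, 3, 2, 4)       # [2, 32, 16, 1, 1]
x_d = pool_d(x).permute(0, 1, 4, 3, 2)       # [2, 32, 13, 1, 1]

# All three now agree on every dim except dim 2, so the cat succeeds:
y = torch.cat([x_h, x_w, x_d], dim=2)        # [2, 32, H+W+D, 1, 1]
```

After the shared transform you could recover the three branches with `torch.split(y, [11, 16, 13], dim=2)` and permute each back to its own axis, analogous to the split in the 2D module.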