Meowuu7 / GeneOH-Diffusion

[ICLR'24] GeneOH Diffusion: Towards Generalizable Hand-Object Interaction Denoising via Denoising Diffusion

Home Page: https://meowuu7.github.io/GeneOH-Diffusion/


RuntimeError in fps (Torch Cluster) Function

Alptekir opened this issue.

I encountered a RuntimeError when using the fps function from the torch_cluster module. The error message suggests an internal assertion failure in fps_cpu.cpp due to an unexpected tensor dimension. Here are the details:

Error Traceback:

File "/home1/Download/GeneOH-Diffusion-main/data_loaders/humanml/data/utils.py", line 214, in farthest_point_sampling
sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
File "/home1/programs/miniconda3/envs/geneoh/lib/python3.8/site-packages/torch_cluster/fps.py", line 107, in fps
return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)
File "/home1/programs/miniconda3/envs/geneoh/lib/python3.8/site-packages/torch/_ops.py", line 854, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: ptr.dim() == 1 INTERNAL ASSERT FAILED at "csrc/cpu/fps_cpu.cpp":17, Input mismatch

Input Shapes:

pos_float shape: torch.Size([3109, 3])
batch shape: torch.Size([3109])
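For context: as far as we can tell from torch_cluster's Python wrapper (fps.py), the `batch` vector is converted into a CSR-style offset vector (`ptr_vec`) before the compiled op is called, and that offset vector is what the `ptr.dim() == 1` assertion checks. The NumPy sketch below reproduces that conversion under this assumption; `batch_to_ptr` is a hypothetical helper, not a torch_cluster API. For the reported input it yields a 1-D offset vector [0, 3109], so the input shapes themselves look valid, which is consistent with the environment later turning out to be the cause.

```python
import numpy as np


def batch_to_ptr(batch: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn a sorted per-point batch vector into offsets.

    Assumption: this mirrors how torch_cluster's fps.py builds `ptr_vec`
    (count points per example, then a cumulative sum with a leading 0).
    """
    batch_size = int(batch.max()) + 1
    deg = np.bincount(batch, minlength=batch_size)  # points per example
    ptr = np.zeros(batch_size + 1, dtype=np.int64)
    ptr[1:] = np.cumsum(deg)  # example i spans indices ptr[i]:ptr[i + 1]
    return ptr


# One point cloud of 3109 points, as in the failing call (all batch ids are 0).
ptr = batch_to_ptr(np.zeros(3109, dtype=np.int64))
print(ptr.tolist(), ptr.ndim)  # [0, 3109] 1

# Two clouds with 3 and 2 points.
print(batch_to_ptr(np.array([0, 0, 0, 1, 1])).tolist())  # [0, 3, 5]
```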

Could you please provide guidance on how to resolve this issue? If there are any specific checks or modifications needed in the code, I would appreciate your assistance.

Hi @Alptekir ,

I haven't encountered this issue before. How did you install the torch_cluster package? Additionally, could you please provide more details about the input to the farthest_point_sampling function, specifically the shape of pos and the value of n_sampling? If possible, could you export these tensors and upload them here? Thanks~

I installed torch_cluster using the provided wheel. The inputs to the farthest_point_sampling function in the HOI4D Toycar example are as follows:

pos shape: torch.Size([1, 13465, 3])
n_sampling: 700

I have attached pos.zip.

Here is the full error traceback:

Original Traceback (most recent call last):
File "/home1/programs/miniconda3/envs/geneoh/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
File "/home1/programs/miniconda3/envs/geneoh/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home1/programs/miniconda3/envs/geneoh/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home1/Download/GeneOH-Diffusion-main/data_loaders/humanml/data/dataset_ours_single_seq.py", line 2151, in __getitem__
base_pts_idxes = utils.farthest_point_sampling(base_pts.unsqueeze(0), n_sampling=nn_base_pts)
File "/home1/Download/GeneOH-Diffusion-main/data_loaders/humanml/data/utils.py", line 214, in farthest_point_sampling
sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
File "/home1/programs/miniconda3/envs/geneoh/lib/python3.8/site-packages/torch_cluster/fps.py", line 107, in fps
return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)
File "/home1/programs/miniconda3/envs/geneoh/lib/python3.8/site-packages/torch/_ops.py", line 854, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: ptr.dim() == 1 INTERNAL ASSERT FAILED at "csrc/cpu/fps_cpu.cpp":17, Input mismatch

I loaded pos.pth and called the farthest_point_sampling function with the following script:

import torch
from torch_cluster import fps


def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
    # pos: (bz, N, 3) point clouds; returns indices into the flattened (bz * N, 3) tensor
    bz, N = pos.size(0), pos.size(1)
    feat_dim = pos.size(-1)
    device = pos.device
    # torch_cluster's fps takes a sampling ratio rather than a point count
    sampling_ratio = float(n_sampling / N)
    pos_float = pos.float()

    # Batch vector assigning each of the bz * N points to its cloud: [0, ..., 0, 1, ..., 1, ...]
    batch = torch.arange(bz, dtype=torch.long).view(bz, 1).to(device)
    mult_one = torch.ones((N,), dtype=torch.long).view(1, N).to(device)

    batch = batch * mult_one
    batch = batch.view(-1)
    pos_float = pos_float.contiguous().view(-1, feat_dim).contiguous()  # (bz x N, 3)
    # sampling_ratio = torch.tensor([sampling_ratio for _ in range(bz)], dtype=torch.float).to(device)
    # batch = torch.zeros((N, ), dtype=torch.long, device=device)
    sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
    return sampled_idx


if __name__ == '__main__':
    pos = torch.load("./data/pos.pth")
    print(pos.size(), pos.device)
    fps_idx = farthest_point_sampling(pos, 700)
    print(fps_idx.size())

It runs without errors, and the output is:

torch.Size([1, 13465, 3]) cpu
torch.Size([700])

Therefore, it seems likely that the error you encountered is related to your environment. Could you please export the details of your environment, for instance, by running conda env export > environment.yml, and then upload the resulting environment.yml file?
Thank you!

Thank you for your help. I have attached my environment in the zip file.
environment.zip
Could I also ask about the torch_cluster fps function? I couldn't find documentation for it on their website.

Thank you for your help. I think the issue was related to the PyTorch and CUDA versions. After switching to an environment with PyTorch 2.2.0 and CUDA 12.1, the problem was solved.

Thanks for your solution! When I inspected the package versions in your environment yesterday, I suspected a version mismatch between torch (2.3) and torch_cluster (a wheel built for torch 2.2), but I hadn't found the time to test it yet. It seems that running pip3 install torch torchvision torchaudio now installs torch 2.3 by default. I've updated the command in README.md to avoid this problem in the future.
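A mismatch like this can be spotted from version strings alone. The sketch below assumes the PyG wheel convention of a local version tag such as 1.6.3+pt22cu121 (pt22 meaning the wheel targets torch 2.2) and reads it via importlib.metadata; the helper names are hypothetical, not part of torch or torch_cluster. A matching major.minor is necessary but not sufficient: the CUDA build should match too.

```python
import re
from typing import Optional, Tuple


def torch_major_minor(torch_version: str) -> Tuple[int, int]:
    """Hypothetical helper: '2.3.0+cu121' -> (2, 3)."""
    major, minor = torch_version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def wheel_torch_tag(cluster_version: str) -> Optional[Tuple[int, int]]:
    """Hypothetical helper: torch version a PyG wheel targets.

    Assumes a local tag like '+pt22cu121' (pt22 -> torch 2.2).
    Returns None when no tag is present (e.g. a source build).
    """
    m = re.search(r"\+pt(\d)(\d+)", cluster_version)
    if m is None:
        return None
    return int(m.group(1)), int(m.group(2))


def versions_match(torch_version: str, cluster_version: str) -> Optional[bool]:
    """True/False if the torch major.minor matches the wheel tag, None if unknown."""
    built_for = wheel_torch_tag(cluster_version)
    if built_for is None:
        return None
    return built_for == torch_major_minor(torch_version)


if __name__ == "__main__":
    try:
        import importlib.metadata as md
        import torch

        cluster_version = md.version("torch_cluster")
        print("torch:", torch.__version__, "CUDA:", torch.version.cuda)
        print("torch_cluster:", cluster_version)
        print("major.minor match:", versions_match(str(torch.__version__), cluster_version))
    except ImportError as exc:  # torch or torch_cluster not installed
        print("could not inspect environment:", exc)
```

With the versions from this thread, versions_match("2.3.0+cu121", "1.6.3+pt22cu121") returns False, while torch 2.2.0 with the same wheel returns True.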

Regarding the fps function: it samples n_sampling points from the input point cloud by greedily selecting, at each step, the point farthest from all points already selected. This yields a roughly uniform subset even when the input point cloud has non-uniform density. We use it here to downsample a point cloud to a smaller one with n_sampling points (torch_cluster's fps takes a sampling ratio, so our wrapper passes n_sampling / N). You can find more details in this Medium post and the DGL documentation.
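To make the greedy procedure concrete, here is a plain NumPy reference sketch for a single point cloud. It is not torch_cluster's implementation: it takes a point count instead of a ratio, handles one cloud rather than a batch, and its fixed start at index 0 is our assumption of what random_start=False does. The function name is hypothetical.

```python
import numpy as np


def farthest_point_sampling_np(points: np.ndarray, n_samples: int, start_idx: int = 0) -> np.ndarray:
    """Greedy farthest point sampling on an (N, D) array (hypothetical reference).

    Each step picks the point whose squared distance to its nearest
    already-selected point is largest.
    """
    n = points.shape[0]
    if not 0 < n_samples <= n:
        raise ValueError("n_samples must be in [1, N]")
    selected = np.empty(n_samples, dtype=np.int64)
    selected[0] = start_idx
    # Squared distance from every point to its nearest selected point so far.
    min_dist = np.sum((points - points[start_idx]) ** 2, axis=1)
    for i in range(1, n_samples):
        nxt = int(np.argmax(min_dist))  # farthest from the current selection
        selected[i] = nxt
        d = np.sum((points - points[nxt]) ** 2, axis=1)
        min_dist = np.minimum(min_dist, d)  # selected points drop to distance 0
    return selected


pts = np.array([[0.0], [1.0], [2.0], [3.0], [10.0]])
print(farthest_point_sampling_np(pts, 3).tolist())  # [0, 4, 3]
```

Each step costs O(N), so sampling 700 of 13465 points needs about 700 passes over the cloud; the running minimum avoids recomputing distances to every selected point.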

Best regards.