xuchen-ethz / fast-snarf

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

error loading filter.cu for torch 1.11

xiyichen opened this issue · comments

Hello,
Nice work!

I am trying to integrate fast-snarf into another project that requires torch>=1.11, because I need some newer torch functionality. I see that you pinned the torch version to 1.10.0. With a newer version, the JIT build of filter.cu fails with the following error messages:



filter_cuda = load(name='filter',
...                    sources=[f'{cuda_dir}/filter/filter.cpp',
...                             f'{cuda_dir}/filter/filter.cu'])
Traceback (most recent call last):
  File "/opt/conda/envs/avatar/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1893, in _run_ninja_build
    subprocess.run(
  File "/opt/conda/envs/avatar/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/envs/avatar/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1284, in load
    return _jit_compile(
  File "/opt/conda/envs/avatar/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1509, in _jit_compile
    _write_ninja_file_and_build_library(
  File "/opt/conda/envs/avatar/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1624, in _write_ninja_file_and_build_library
    _run_ninja_build(
  File "/opt/conda/envs/avatar/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1909, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'filter': [1/3] /usr/local/cuda-11.3/bin/nvcc  -DTORCH_EXTENSION_NAME=filter -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda-11.3/include -isystem /opt/conda/envs/avatar/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -std=c++17 -c root/avatar/lib/cuda/filter/filter.cu -o filter.cuda.o 
FAILED: filter.cuda.o 
/usr/local/cuda-11.3/bin/nvcc  -DTORCH_EXTENSION_NAME=filter -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda-11.3/include -isystem /opt/conda/envs/avatar/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 --compiler-options '-fPIC' -std=c++17 -c root/avatar/lib/cuda/filter/filter.cu -o filter.cuda.o 
root/avatar/lib/cuda/filter/filter.cu(72): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(73): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: no instance of function template "filter" matches the argument list
            argument types are: (int, <error-type>, int, <error-type>, <error-type>, int, <error-type>, <error-type>, int, <error-type>)

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: no instance of function template "filter" matches the argument list
            argument types are: (int, <error-type>, int, <error-type>, <error-type>, int, <error-type>, <error-type>, int, <error-type>)

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: incomplete class type "at::Tensor" is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: type name is not allowed

root/avatar/lib/cuda/filter/filter.cu(77): error: argument list for class template "at::RestrictPtrTraits" is missing

root/avatar/lib/cuda/filter/filter.cu(77): error: expected an expression

root/avatar/lib/cuda/filter/filter.cu(77): error: no instance of function template "filter" matches the argument list
            argument types are: (int, <error-type>, int, <error-type>, <error-type>, int, <error-type>, <error-type>, int, <error-type>)

42 errors detected in the compilation of "root/avatar/lib/cuda/filter/filter.cu".
[2/3] c++ -MMD -MF filter.o.d -DTORCH_EXTENSION_NAME=filter -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/TH -isystem /opt/conda/envs/avatar/lib/python3.8/site-packages/torch/include/THC -isystem /usr/local/cuda-11.3/include -isystem /opt/conda/envs/avatar/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c root/avatar/lib/cuda/filter/filter.cpp -o filter.o 
ninja: build stopped: subcommand failed.

Do you know whether anything in the CUDA code needs to be modified to make it compatible with torch 1.11.0?

Hi Xinyi,

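If you are using a torch version later than 1.10, you can replace filter.cu with the following modified version. When copying it, make sure the string literals use plain ASCII double quotes (e.g. "filter"), not typographic quotes: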

#include <vector>
#include <iostream>
#include <torch/extension.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/CUDAContext.h>
// Near-duplicate filter over candidate points, one work item per
// (batch, point) pair.
//
// For every (i_batch, i_point) in [0, nthreads) and every candidate
// i in [0, n_init):
//   * if mask[i_batch][i_point][i] is false, output[...][i] is set to false;
//   * otherwise output[...][i] is true unless some later candidate j > i that
//     is also masked-in has squared Euclidean distance (over components 0..2)
//     < 0.0001*0.0001 to candidate i, in which case it is false.
// Consequently the last masked-in candidate of any close group is kept.
//
// Launch: 1D; must use <= 512 threads per block (C10_LAUNCH_BOUNDS_1(512)).
// CUDA_KERNEL_LOOP_TYPE is ATen's grid-stride loop over [0, nthreads), so any
// grid size is correct.
//
// Params:
//   nthreads  number of (batch, point) pairs, i.e. size(0) * size(1).
//   x         rank-4 accessor; reads x[b][p][i][0..2], so size(3) >= 3.
//   mask      rank-3 bool accessor; its size(1)/size(2) drive the loops.
//   output    rank-3 bool accessor, written in full for each visited pair.
// NOTE(review): assumes x, mask and output share their first three dims —
// the accessors do no bounds checking, so confirm this at the call site.
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void filter(
    const index_t nthreads,
    torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> x,
    torch::PackedTensorAccessor32<bool, 3, torch::RestrictPtrTraits> mask,
    torch::PackedTensorAccessor32<bool, 3, torch::RestrictPtrTraits> output) {
    // Squared distance under which two candidates count as duplicates.
    // Deliberately a double constant (same value and type as the original
    // inline `0.0001*0.0001`): a half-typed 1e-8 would underflow to zero.
    constexpr double kDupDistSq = 0.0001 * 0.0001;

    const index_t n_point = mask.size(1);
    const index_t n_init = mask.size(2);

    CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
        // Flat index -> (batch, point). index == i_batch * n_point + i_point.
        // (Dividing by n_batch*n_point instead would always give batch 0 and
        // let i_point exceed n_point, reading/writing out of bounds for B > 1.)
        const index_t i_batch = index / n_point;
        const index_t i_point = index % n_point;

        // Row views for this (batch, point); hoisted out of the inner loops.
        auto x_row = x[i_batch][i_point];
        auto mask_row = mask[i_batch][i_point];
        auto out_row = output[i_batch][i_point];

        for (index_t i = 0; i < n_init; i++) {
            if (!mask_row[i]) {
                out_row[i] = false;
                continue;
            }
            const scalar_t xi0 = x_row[i][0];
            const scalar_t xi1 = x_row[i][1];
            const scalar_t xi2 = x_row[i][2];

            // Candidate i survives only if no later masked-in candidate is
            // within the duplicate threshold.
            bool keep = true;
            for (index_t j = i + 1; j < n_init; j++) {
                if (!mask_row[j]) {
                    continue;
                }
                const scalar_t d0 = xi0 - x_row[j][0];
                const scalar_t d1 = xi1 - x_row[j][1];
                const scalar_t d2 = xi2 - x_row[j][2];
                const scalar_t dist = d0 * d0 + d1 * d1 + d2 * d2;
                if (dist < kDupDistSq) {
                    keep = false;
                    break;
                }
            }
            out_row[i] = keep;
        }
    }
}
// Host-side launcher for the `filter` kernel.
//
// Launches one work item per (batch, point) pair of `output`
// (output.size(0) * output.size(1) in total) on the current CUDA stream.
// It dispatches on x's dtype (half, float or double via
// AT_DISPATCH_FLOATING_TYPES_AND_HALF). Nothing is launched when that count
// is zero. Launch errors are surfaced by C10_CUDA_KERNEL_LAUNCH_CHECK.
//
// Params:
//   output  rank-3 bool tensor, written in place by the kernel (the const
//           reference is to the Tensor handle, not to its storage).
//   x       rank-4 floating tensor; the kernel reads components 0..2 of its
//           last dim.
//   mask    rank-3 bool tensor selecting which candidates participate.
// NOTE(review): no shape or device checks are done here. This assumes output,
// mask and the first three dims of x agree, and that all tensors live on the
// current CUDA device — confirm against the caller.
void launch_filter(
    const torch::Tensor &output,
    const torch::Tensor &x,
    const torch::Tensor &mask) {
  // Must not exceed the kernel's C10_LAUNCH_BOUNDS_1(512).
  constexpr int kThreadsPerBlock = 512;

  // One work item per (batch, point) pair.
  const auto B = output.size(0);
  const auto N = output.size(1);
  const int64_t count = B * N;
  if (count <= 0) {
    return;
  }

  // Plain ASCII quotes are required for the dispatch name; typographic
  // quotes (as produced by some copy/paste paths) do not compile.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filter", [&] {
    // packed_accessor32 requires 32-bit-indexable tensors, so count (the
    // product of two of x's dims) fits in int here.
    filter<scalar_t>
        <<<at::cuda::detail::GET_BLOCKS(count, kThreadsPerBlock),
           kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>(
            static_cast<int>(count),
            x.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(),
            mask.packed_accessor32<bool, 3, torch::RestrictPtrTraits>(),
            output.packed_accessor32<bool, 3, torch::RestrictPtrTraits>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}