rapidsai / raft

RAFT contains fundamental widely-used algorithms and primitives for machine learning and information retrieval. The algorithms are CUDA-accelerated and form building blocks for more easily writing high performance applications.

Home Page: https://docs.rapids.ai/api/raft/stable/

[BUG] Compiling the `template` fails on RTX 4090

JieFengWang opened this issue

Describe the bug

After installing RAFT via mamba, I tried to compile the template on my machine. The compile fails with this error:

"ptxas error : Value of threads per SM for entry _ZN4raft9neighbors12experimental10nn_descent6detail17local_join_kernelIiNS3_12InternalID_tIiEEEEvPKT_S9_PK4int2S9_S9_SC_iPK6__halfiPT0_PfiPiSI_ is out of range. .minnctapersm will be ignored
ptxas fatal : Ptx assembly aborted due to errors
gmake[2]: *** [_deps/raft-build/CMakeFiles/raft_objs.dir/build.make:1813: _deps/raft-build/CMakeFiles/raft_objs.dir/src/raft_runtime/neighbors/cagra_build.cu.o] Error 255"
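For context, `.minnctapersm` is the PTX directive produced by the second argument of `__launch_bounds__`: it asks ptxas to guarantee at least that many resident blocks per SM. A minimal sketch of the construct behind the error (hypothetical kernel, not the RAFT source):

// __launch_bounds__(512, 4) asks ptxas to keep 4 blocks of 512 threads
// resident per SM, i.e. 512 * 4 = 2048 resident threads. On a GPU whose
// hardware cap is lower, the .minnctapersm request is out of range and
// ptxas reports it, exactly as in the log above.
__global__ void __launch_bounds__(512, 4) toy_kernel(float* out)
{
  out[blockIdx.x * blockDim.x + threadIdx.x] = 0.0f;
}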

  • GPU: RTX 4090
  • Ubuntu 22.04
  • nvcc 12.2
  • g++ 11.4
  • cmake 3.28
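The hardware cap that this error enforces can be queried directly; a standalone sketch (assuming device 0 is the RTX 4090):

// Prints the maximum number of resident threads per SM for device 0.
// On sm_89 (RTX 4090) this is 1536, so a request for 512 * 4 = 2048
// resident threads cannot be honored.
#include <cstdio>
#include <cuda_runtime.h>

int main()
{
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, 0);  // assumption: device 0 is the 4090
  std::printf("sm_%d%d: maxThreadsPerMultiProcessor = %d\n",
              prop.major, prop.minor, prop.maxThreadsPerMultiProcessor);
  return 0;
}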

Steps/Code to reproduce bug

cd template
./build.sh

Expected behavior

  • I want the template to compile successfully.

Environment details (please complete the following information):

  • Environment location: [Bare-metal]
  • Method of RAFT install: [mamba]

Additional context

  • I repeated the same steps on a GTX 1080 Ti and an RTX 3090; both succeed. Only the RTX 4090 fails.

Still the same error: when compiling the template, it shows "ptxas error : Value of threads per SM for entry _ZN4raft9neighbors12experimental10nn_descent6detail17local_join_kernelIiNS3_12InternalID_tIiEEEEvPKT_S9_PK4int2S9_S9_SC_iPK6__halfiPT0_PfiPiSI_ is out of range. .minnctapersm will be ignored".

Does this mean some `InternalID_t` value is out of range? The mangled kernel name instantiates local_join_kernel with InternalID_t<int>:

template <typename Index_t>
struct InternalID_t;

// InternalID_t uses 1 bit for marking (new or old).
template <>
class InternalID_t<int> {
 private:
  using Index_t = int;
  Index_t id_{std::numeric_limits<Index_t>::max()};

 public:
  __host__ __device__ bool is_new() const { return id_ >= 0; }
  __host__ __device__ Index_t& id_with_flag() { return id_; }
  __host__ __device__ Index_t id() const
  {
    if (is_new()) return id_;
    return -id_ - 1;
  }
  __host__ __device__ void mark_old()
  {
    if (id_ >= 0) id_ = -id_ - 1;
  }
  __host__ __device__ bool operator==(const InternalID_t<int>& other) const
  {
    return id() == other.id();
  }
};
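For what it's worth, the encoding itself is consistent; a small host-side sketch of the 1-bit new/old round trip (assuming the class above is in scope):

// A non-negative id_ means "new"; mark_old() maps id -> -id - 1 (keeping 0
// representable), and id() recovers the original value in either state.
#include <cassert>

int main()
{
  InternalID_t<int> v;
  v.id_with_flag() = 42;                // store a "new" id
  assert(v.is_new() && v.id() == 42);
  v.mark_old();                         // id_ becomes -43, flagged "old"
  assert(!v.is_new() && v.id() == 42);  // id() still recovers 42
  return 0;
}

So nothing here is numerically out of range; as the error text itself says, the out-of-range value is `.minnctapersm`, the requested minimum blocks per SM.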

UPDATE:
I removed the entire mamba/miniforge environment and deleted all RAFT source code and the compiled/installed libraries from my Ubuntu machine, leaving ONLY the template folder. Running ./build.sh still fails with the same error. I then changed line 398 of nn_descent.cuh (at /path/to/template/build/_deps/raft-src/cpp/include/raft/neighbors/detail/nn_descent.cuh) to

constexpr int BLOCK_SIZE                = 256;  // was 512

Reducing BLOCK_SIZE from 512 to 256 makes the compile error go away: the launch bounds then request 256 * 4 = 1024 resident threads, which fits under the 4090's 1536-per-SM cap.

BUT nn_descent then runs even slower than on the RTX 3090, and the kNN graph quality (recall) is nearly zero.

OK, fixed.

Change line 694 of /path/to/template/build/_deps/raft-src/cpp/include/raft/neighbors/detail/nn_descent.cuh to

#if (__CUDA_ARCH__) == 750 || (__CUDA_ARCH__) == 860 || (__CUDA_ARCH__) == 890

and the bug is fixed.

  • since I saw this comment in the source:
// launch_bounds here denote BLOCK_SIZE = 512 and MIN_BLOCKS_PER_SM = 4
// Per
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications,
// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
// For architectures 750 and 860, the values for MAX_RESIDENT_THREAD_PER_SM
// is 1024 and 1536 respectively, which means the bounds don't work anymore

and I found that the RTX 4090's MAX_RESIDENT_THREAD_PER_SM is also 1536, and its architecture is 890 (compute capability 8.9), so I added || (__CUDA_ARCH__) == 890 to this condition.
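A sketch of the resulting structure (paraphrased, not the verbatim nn_descent.cuh source): architectures whose resident-thread cap is below 2048 drop the minimum-blocks argument, so ptxas only enforces the 512-thread block-size bound.

// sm_75 (cap 1024), sm_86 (cap 1536) and sm_89 (cap 1536) cannot host
// 512 * 4 = 2048 resident threads, so they get __launch_bounds__ without
// the minBlocksPerMultiprocessor request.
#if (__CUDA_ARCH__) == 750 || (__CUDA_ARCH__) == 860 || (__CUDA_ARCH__) == 890
__global__ void __launch_bounds__(512)     // bound block size only
local_join_sketch(const int* in, int* out)
#else
__global__ void __launch_bounds__(512, 4)  // also ask for 4 resident blocks/SM
local_join_sketch(const int* in, int* out)
#endif
{
  int i  = blockIdx.x * blockDim.x + threadIdx.x;
  out[i] = in[i];
}

(In the host compilation pass `__CUDA_ARCH__` is undefined, so the #else branch is parsed there; that is harmless, since launch bounds only affect device code generation.)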

Since this PR has been merged, this issue will be closed.