NVIDIA / cub

[ARCHIVED] Cooperative primitives for CUDA C++. See https://github.com/NVIDIA/cccl

BlockRadixRankMatch produces invalid results when warp size does not divide block size

Snektron opened this issue

As the title suggests, when the device warp size does not divide the block size exactly, BlockRadixRankMatch may produce invalid results. This seems to be because the algorithm uses warp-level instructions that do not take the actual launch bounds into account. Specifically, this call to the match.any emulation also returns set bits for lanes that do not participate in the warp:

uint32_t peer_mask = MatchAny<RADIX_BITS>(digit);
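The core of the problem can be seen with a host-side model of a single warp vote. This is only an illustration, assuming that lanes which do not exist in a partial warp contribute a 0 bit to the ballot (which is what the reported output suggests); `ballot_partial_warp` and `inverted_ballot` are hypothetical helpers, not CUB or CUDA APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side model of __ballot_sync in a partial warp:
// bit i is set only if lane i exists (i < predicates.size()) and its predicate is true.
// Lanes that are not part of the launch never contribute a 1 bit.
std::uint32_t ballot_partial_warp(const std::vector<bool>& predicates) {
    std::uint32_t mask = 0;
    for (std::size_t lane = 0; lane < predicates.size() && lane < 32; ++lane) {
        if (predicates[lane]) {
            mask |= 1u << lane;
        }
    }
    return mask;
}

// A lane whose predicate is false inverts the ballot to find lanes that agree with it.
// The inversion also turns on the bit of every nonexistent lane.
std::uint32_t inverted_ballot(const std::vector<bool>& predicates) {
    return ~ballot_partial_warp(predicates);
}
```

With two active lanes where only lane 1 has the bit set, the ballot is `0x2`, and lane 0's inverted mask is `0xFFFFFFFD`: lane 0 plus the 30 lanes 2 to 31 that were never launched.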

This code reproduces the bug for me on both a Titan V and an RTX 3090:

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_radix_rank.cuh>
#include <cub/block/radix_rank_sort_operations.cuh>

#include <vector>
#include <iostream>  // std::cout

template<unsigned block_size, unsigned items_per_thread>
__global__ __launch_bounds__(block_size) void kernel(const unsigned* keys, int* ranks) {
    constexpr unsigned items_per_block = block_size * items_per_thread;
    const unsigned tid = threadIdx.x;
    const unsigned block_offset = blockIdx.x * items_per_block;

    unsigned thread_keys[items_per_thread];
    cub::LoadDirectWarpStriped(tid, keys + block_offset, thread_keys);

    cub::BFEDigitExtractor<unsigned> digit_extractor(0, 5);
    int thread_ranks[items_per_thread];

    using Ranker = cub::BlockRadixRankMatch<block_size, 5, false>;
    __shared__ typename Ranker::TempStorage storage;

    Ranker ranker(storage);
    ranker.RankKeys(thread_keys, thread_ranks, digit_extractor);

    cub::StoreDirectWarpStriped(tid, ranks + block_offset, thread_ranks);
}

int main() {
    constexpr unsigned size = 2; // Not a multiple of the warp size.
    std::vector<unsigned> keys = {0, 1};

    unsigned* d_keys;
    cudaMalloc(&d_keys, size * sizeof(unsigned));

    int* d_ranks;
    cudaMalloc(&d_ranks, size * sizeof(int));

    cudaMemcpy(d_keys, keys.data(), size * sizeof(unsigned), cudaMemcpyHostToDevice);

    (kernel<size, 1>)<<<1, size>>>(d_keys, d_ranks);

    cudaDeviceSynchronize();

    std::vector<int> ranks(size);
    cudaMemcpy(ranks.data(), d_ranks, size * sizeof(int), cudaMemcpyDeviceToHost);

    for (int i = 0; i < size; ++i) {
        std::cout << "[" << i << "] " << keys[i] << " expected=" << i << " actual=" << ranks[i] << std::endl;
    }

    cudaFree(d_keys);
    cudaFree(d_ranks);
}

output:

[0] 0 expected=0 actual=0
[1] 1 expected=1 actual=31

(Note that since this involves undefined data, I'm not sure the above always reproduces it.)
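The value 31 is consistent with spurious peers being counted. The sketch below uses a simplified counting model, which is an assumption and not a transcription of CUB's counter and scan logic: the lowest lane in each peer group reports the group size as that digit's count, and a key's rank is the total count of all smaller digits plus the number of lower lanes sharing its digit. If lane 0 (digit 0) receives the peer mask `0xFFFFFFFD` (itself plus the nonexistent lanes 2 to 31), digit 0 is counted as 31 keys and key 1 lands at rank 31. `rank_of` is a hypothetical helper.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified ranking model (hypothetical, not CUB's implementation).
// Assumes at most 32 lanes, so every shift count is below 32.
int rank_of(std::size_t lane,
            const std::vector<unsigned>& digits,
            const std::vector<std::uint32_t>& peer_masks) {
    // Lower lanes that share this lane's digit.
    const std::uint32_t lower_lanes = (1u << lane) - 1u;
    int rank = static_cast<int>(std::bitset<32>(peer_masks[lane] & lower_lanes).count());

    // Add the count reported for every smaller digit by that digit's lowest ("leader") lane.
    for (std::size_t other = 0; other < digits.size(); ++other) {
        const std::uint32_t below_other = (1u << other) - 1u;
        const bool is_leader = (peer_masks[other] & below_other) == 0;
        if (is_leader && digits[other] < digits[lane]) {
            rank += static_cast<int>(std::bitset<32>(peer_masks[other]).count());
        }
    }
    return rank;
}
```

Feeding in the correct peer masks `{0x1, 0x2}` instead gives rank 1 for key 1, as expected.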

@Snektron thank you for reporting this! You are right, the issue is in MatchAny. Here's a translation of this function into plain CUDA C++:

template <int LABEL_BITS>
__device__ unsigned MatchAnyC(unsigned int label) {
  unsigned int retval;

  for (int BIT = 0; BIT < LABEL_BITS; ++BIT) {
    unsigned current_bit = 1 << BIT;

    // mask has a single bit from label
    unsigned mask = label & current_bit;

    // check if the bit is set or not
    unsigned p = mask == current_bit;

    // vote to find threads that have the bit set
    mask = __ballot_sync(0xFFFFFFFF, p);

    // if the label bit was not set
    if (!p) {
      // invert mask to find other threads with the bit unset
      mask = ~mask;
    }

    // Remove peers who differ
    retval = (BIT == 0) ? mask : retval & mask;
  }

  return retval;
}
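To see how this translation misbehaves in a partial warp without a GPU, here is a host-side model of it for one warp. It is a sketch under one assumption: the emulated ballot sets a bit only for lanes that exist and have the bit set, so lanes at index `labels.size()` and above contribute 0. `match_any_emulated` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Host-side model of MatchAnyC for one warp. labels[i] is the label of existing lane i;
// lanes at index >= labels.size() (at most 32 lanes) do not exist.
template <int LABEL_BITS>
std::vector<std::uint32_t> match_any_emulated(const std::vector<unsigned>& labels) {
    // Start from "everyone is a peer" and remove lanes that differ in any bit.
    std::vector<std::uint32_t> result(labels.size(), 0xFFFFFFFFu);
    for (int bit = 0; bit < LABEL_BITS; ++bit) {
        const unsigned current_bit = 1u << bit;

        // Emulated __ballot_sync: only existing lanes can set their bit.
        std::uint32_t ballot = 0;
        for (std::size_t lane = 0; lane < labels.size(); ++lane) {
            if ((labels[lane] & current_bit) == current_bit) {
                ballot |= 1u << lane;
            }
        }

        for (std::size_t lane = 0; lane < labels.size(); ++lane) {
            const bool p = (labels[lane] & current_bit) == current_bit;
            // Lanes with the bit unset invert the ballot, which also sets the
            // bits of every nonexistent lane.
            const std::uint32_t mask = p ? ballot : ~ballot;
            result[lane] &= mask;
        }
    }
    return result;
}
```

For the keys `{0, 1}` from the report, lane 1 correctly gets `0x2`, but lane 0 gets `0xFFFFFFFD`, so lanes 2 to 31 are treated as peers with digit 0.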

Apparently, exited (or nonexistent) threads don't set their bit in the ballot. After the ballot, the code assumes that a zero bit in the mask means that thread's predicate evaluated to false. When the mask is inverted for threads that have BIT unset, the missing threads therefore appear to have the bit unset too, instead of being excluded as inactive. A possible solution might be:

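template <int LABEL_BITS>
__device__ unsigned MatchAnyF(unsigned int label, unsigned int warp_threads) {
  // Keep only the lanes that exist; shifting a 32-bit value by 32 (full warp) is undefined.
  unsigned int valid_lanes = (warp_threads >= 32) ? 0xFFFFFFFFu : ((1u << warp_threads) - 1u);
  return MatchAny<LABEL_BITS>(label) & valid_lanes;
}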
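The masking idea can be checked on the host. One caveat: an expression such as `~(~0 << warp_threads)` shifts a signed `int`, and a shift count of 32 (a full warp) is undefined behavior, so the sketch below handles that case separately. It assumes the existing lanes of a partial warp are the lowest `warp_threads` lanes, which holds for the last warp of a 1D block; `valid_lane_mask` and `fix_peer_mask` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Mask with the lowest `warp_threads` bits set: the lanes that exist in a
// (possibly partial) warp. The full-warp case avoids an undefined shift by 32.
constexpr std::uint32_t valid_lane_mask(unsigned warp_threads) {
    return warp_threads >= 32u ? 0xFFFFFFFFu : ((1u << warp_threads) - 1u);
}

// Applies the proposed fix to a peer mask produced by the emulated MatchAny.
constexpr std::uint32_t fix_peer_mask(std::uint32_t peer_mask, unsigned warp_threads) {
    return peer_mask & valid_lane_mask(warp_threads);
}
```

For the two-thread block in the report, lane 0's mask `0xFFFFFFFD` becomes `0x1` and lane 1's `0x2` is unchanged, which are the correct peer sets.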

But instead, I'd like to take some time to experiment with the specialized Volta+ instructions. It may be possible to replace the whole facility with something like:

template <int LABEL_BITS>
__device__ unsigned MatchAnyV(unsigned int label) {
  // __match_any_sync requires compute capability 7.0 (Volta) or newer.
  return __match_any_sync(0xFFFFFFFF, label & ((1u << LABEL_BITS) - 1u));
}
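`__match_any_sync(mask, value)` returns the mask of threads in `mask` that have the same `value`. A host-side reference model of the result expected in a partial warp, where only existing lanes can ever appear, could serve as an oracle when testing either the emulated or the native path. This is a sketch; `match_any_reference` is a hypothetical name, and it assumes at most 32 lanes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference model: for each existing lane, the set of existing lanes whose
// label equals its own. Nonexistent lanes never appear in any result.
std::vector<std::uint32_t> match_any_reference(const std::vector<unsigned>& labels) {
    std::vector<std::uint32_t> result(labels.size(), 0u);
    for (std::size_t lane = 0; lane < labels.size(); ++lane) {
        for (std::size_t other = 0; other < labels.size(); ++other) {
            if (labels[other] == labels[lane]) {
                result[lane] |= 1u << other;
            }
        }
    }
    return result;
}
```

For the keys `{0, 1}` this yields `{0x1, 0x2}`, the peer masks the ranking needs to produce ranks 0 and 1.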