bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.

Home Page:https://huggingface.co/docs/bitsandbytes/main/en/index

A bit of confusion about the function kdequant_mm_int32_fp16().

ahqzy opened this issue · comments

System Info

I am studying the function kdequant_mm_int32_fp16() and encountered a confusing issue. I constructed a matrix with a shape of [3, 96] and values ranging from 0 to 287 as follows:
[  0,   1, ...,  30,  31,  32,  33, ...,  62,  63,  64,  65, ...,  94,  95,
  96,  97, ..., 191,
 192, 193, ..., 287]

When I added print statements in the kdequant_mm_int32_fp16() function, I found that for the value 32, the calculated column index is 0 and the row index is 1. According to my understanding, since the matrix has 96 columns, shouldn't 32 be located at row 0, column 32? Similarly, for the value 96, the calculated column index is 32 and the row index is 0. Shouldn't 96 be located at row 1, column 0? The print statement was added at line 145 of the code:
(screenshot of the added print statement; image not available)

Log Information:
(screenshot of the log output; image not available)

Is it possible that the calculated column and row indices in the code do not directly correspond to the rows and columns of the matrix? How should I interpret this?

Reproduction

The test code is:

#include <iostream>
#include <iomanip>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
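#include <cub/cub.cuh> // provides cub::BlockLoad and cub::BlockExchange used by the kernel
#include <cassert>     // assert() in dequant_mm_int32_fp16
#include <cstdlib>     // exit() in CUDA_CHECK_RETURN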
#include <cstdio>
#include <random>

/*
 nvcc -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80  -gencode arch=compute_86,code=sm_86 
 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 
 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70   -o dequant_int32_fp16  dequant_int32_fp16.cu 
 -L /usr/local/cuda/lib64  -lcudart -lcublas -lcublasLt -lcusparse
*/

#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)

#define CUDA_CHECK_RETURN(value) {                      \
  cudaError_t _m_cudaStat = value;                    \
  if (_m_cudaStat != cudaSuccess) {                   \
    fprintf(stderr, "Error %s at line %d in file %s\n",         \
        cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__);   \
    exit(1);                              \
  } }

template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>
__global__ void kdequant_mm_int32_fp16(
    int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
    half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias,
    const int numRows, const int numCols, const int tileCols, const int n)
{
  // Strategy: To dequantize we need to load col/row statistics. This can be very expensive
  // since different row/col stats need to be loaded with each thread.
  // (1, bad algorithm) Loading 32 items per thread would only incur 1 row load, but this increases register pressure
  // and would lead to low global load utilization.
  // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do a lot of row loads
  // for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
  // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
  // This allows for efficient row/col loading from shared memory within the tile.
  // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
  // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
  // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
  // shared memory loads.

  // data is in 32 column-tile major with tile width 32 columns and numRows rows
  // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
  // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
  // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127) into register)
  // C2. Compute normalization values and store col values in register
  // S1. Store C1 into 16-bit output
  // S2. Store col/row statistics of new buffer in shared memory

  // We allow for sub-tiles to span multiple col32 tiles. This is okay
  // since the items per thread only rely on a single column statistic.

  //printf("----- %d %d %d %d %d %d %d %d\n", 
  //        A[0], A[1], A[2], A[3], A[4], A[5], A[6], A[7]);

  const int n_out = numRows*numCols;


  int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
  // we have tiles of size numRows*32, thus col only increases every numRows
  // num_row_tiles is the tiles after which the column increases by 32
  // blockIdx.x is the index of the current tile
  int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));

  // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
  int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);

  //printf("num_row_tiles = %d, col = %d, base_row = %d, blockDim.x = %d\n", num_row_tiles, col, base_row, blockDim.x);

  // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
  // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
  // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROWS*32/4 threads.
  // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have
  // 1024*1024/(128*32) = 256 tiles
  // 256 tiles are 256*128*32/4 = 256*1024 threads

  // 1. Figure out how index relates to the start of the sub-tile
  // 2. Each thread < SUBTILE_ROWS calculates row index
  // 3. Load striped and store in shared memory

  int local_values[ITEMS_PER_THREAD];
  half local_output[ITEMS_PER_THREAD];
  float local_rowStats[ITEMS_PER_THREAD];
  __shared__ float smem_rowStats[SUBTILE_ROWS];

  typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
  typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
  __shared__ typename LoadInt32::TempStorage loadint32;
  __shared__ typename ExchangeInt32::TempStorage exchangeint32;

  // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
  float colStat = col >= numCols ? 0.0f : colStats[col];

  float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);

  // no block loads for rows for now -- keep it simple
  for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
  {
    // todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
    int row = (base_row+j) % numRows; // wrap around
    // each warp accesses the same element, for four consecutive elements
    // todo: update description about striped shared memory, it is not needed
    // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consecutive elements
    smem_rowStats[j] = rowStats[row];
  }

  __syncthreads();

  // each block processes SUBTILE_ROWS*32 elements
  const int items_per_load = THREADS*ITEMS_PER_THREAD;  
  const int rows_per_load = items_per_load/32;  

  int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
  int row_offset = 0;
  // subtile_start is base_row*32 plus the offset of the full numRows*32 column tiles already passed
  int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32);

  int cnt = 0;

  for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load)
  {
    cnt++;
    int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset);
    int valid_items = valid_rows*32;
    if(valid_items <= 0) // the sub-tile might have more elements than the tile itself
      break;

    // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
    LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0);

    ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values);

    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
    {
      local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
    }  

    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
    {
      local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue);
      printf("j = %d, get local_values:%d, blockIdx.x=%d, col=%d, threadIdx.x=%d, local_rowStats:%f, colStat:%f\n",
              j, local_values[j], blockIdx.x, col, threadIdx.x, local_rowStats[j], colStat);
    }  
      //absmax_col = fmax(fabsf(local_output[j]), absmax_col);

    // we store data in row major
    // to store data efficiently, we want to use block exchange: [0, 32, 64, 96] -> [0, 1, 2, 3]
    // so that each thread holds ITEMS_PER_THREAD consecutive items for each row
    // this way throughput into storage is increased by a factor of ~2x
    // for now we use a simple store
    #pragma unroll ITEMS_PER_THREAD
    for(int j = 0; j < ITEMS_PER_THREAD; j++)
    {
      int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols);
      if(outIdx< n_out && col < numCols)
      {
        out[outIdx] = local_output[j];
      }
    }

    row_offset += rows_per_load;
  }
}

int fill_up_to_nearest_multiple(int value, int multiple)
{
  return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}

void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
{
  int threads = 512;
  int tileCols = fill_up_to_nearest_multiple(numCols, 32);
  int n = numRows*tileCols;
  int subtile_rows = 128;
  int tilesize = 32*subtile_rows;
  int num_blocks = numRows/subtile_rows;
  num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
  num_blocks = num_blocks*(tileCols/32);

  printf("===== num_blocks = %d, tileCols = %d\n", num_blocks, tileCols);

  assert(threads <= tilesize);

  kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void printMatrix(const half* matrix, int numRows, int numCols) 
{
  for (int i = 0; i < numRows; ++i) 
  {
    for (int j = 0; j < numCols; ++j)
    {
      printf("out[%d][%d]: %f\n", i, j,  __half2float(matrix[i * numCols + j]));
    }

    printf("\n");
  }
}

int main() {
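    // const so that the array sizes below are compile-time constants (no variable-length arrays)
    const int numRows = 3;
    const int valid_numCols = 96;

    const int numCols = 96;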

    int A[numRows*numCols];

    for (int i = 0; i < numRows*numCols; i++)
    {
      A[i] = i;
    }

    float rowStats[numRows];
    for (int i = 0; i < numRows; i++)
    {
      rowStats[i] = 0.001*i + 1.0;
    }

    float colStats[valid_numCols];
    for (int i = 0; i < valid_numCols; i++)
    {
      colStats[i] = 0.001*i + 2.0;
    } 

    float new_row_stats[numRows];
    for (int i = 0; i < numRows; i++)
    {
      new_row_stats[i] = 0.001*i + 3.0;
    }

    float new_col_stats[valid_numCols];
    for (int i = 0; i < valid_numCols; i++)
    {
      new_col_stats[i] = 0.001*i + 4.0;
    } 

    int* d_A;
    float *d_rowStats, *d_colStats, *d_newRowStats, *d_newColStats;
    half *d_out;

    CUDA_CHECK_RETURN(cudaMalloc((void**)&d_A, sizeof(int) * numRows*numCols));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&d_rowStats, sizeof(float) * numRows));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&d_colStats, sizeof(float) * valid_numCols));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&d_newRowStats, sizeof(float) * numRows));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&d_newColStats, sizeof(float) * valid_numCols));
    CUDA_CHECK_RETURN(cudaMalloc((void**)&d_out, sizeof(half) * numRows * valid_numCols));


    CUDA_CHECK_RETURN(cudaMemcpy(d_A, A, sizeof(int)* numRows * numCols, cudaMemcpyHostToDevice));
    CUDA_CHECK_RETURN(cudaMemcpy(d_rowStats, rowStats, sizeof(float) * numRows, cudaMemcpyHostToDevice));
    CUDA_CHECK_RETURN(cudaMemcpy(d_colStats, colStats, sizeof(float) * valid_numCols, cudaMemcpyHostToDevice));
    CUDA_CHECK_RETURN(cudaMemcpy(d_newRowStats, new_row_stats, sizeof(float) * numRows, cudaMemcpyHostToDevice));
    CUDA_CHECK_RETURN(cudaMemcpy(d_newColStats, new_col_stats, sizeof(float) * valid_numCols, cudaMemcpyHostToDevice));

    dequant_mm_int32_fp16(d_A, d_rowStats, d_colStats, d_out, d_newRowStats, d_newColStats, NULL, numRows, valid_numCols);

    half* h_out = new half[numRows * valid_numCols];

    CUDA_CHECK_RETURN(cudaMemcpy(h_out, d_out, sizeof(half) * numRows * valid_numCols, cudaMemcpyDeviceToHost));

    //std::cout << "Output Matrix:" << std::endl;
    //printMatrix(h_out, numRows, valid_numCols);

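    // release every device buffer and the host copy of the output
    CUDA_CHECK_RETURN(cudaFree(d_A));
    CUDA_CHECK_RETURN(cudaFree(d_rowStats));
    CUDA_CHECK_RETURN(cudaFree(d_colStats));
    CUDA_CHECK_RETURN(cudaFree(d_newRowStats));
    CUDA_CHECK_RETURN(cudaFree(d_newColStats));
    CUDA_CHECK_RETURN(cudaFree(d_out));
    delete[] h_out;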

    return 0;
}

Expected behavior

please help me, thx!

Note the explanation in this comment: // each block processes SUBTILE_ROWS*32 elements.

Do you see the results you expect when you look at the full output?