I cannot understand the `cublasGemmStridedBatchedEx` call in `attention_forward`
echosprint opened this issue · comments
In `attention_forward` we need to compute Q * K^T, where Q and K are each stored row-major with shape (B, NH, T, HS), right?
But `cublasGemmStridedBatchedEx` would treat each (T, HS) matrix as (HS, T), because cuBLAS is column-major by default.
So why do we pass (T, T, HS) for the (M, N, K) parameters instead of (HS, T, HS)?
cublasCheck(cublasGemmStridedBatchedEx(cublas_handle,
    CUBLAS_OP_T, CUBLAS_OP_N,
    T, T, HS, &alpha,
    k, CUBLAS_LOWP, HS, T * HS,
    q, CUBLAS_LOWP, HS, T * HS,
    &beta, preatt, CUBLAS_LOWP, T, T * T,
    B * NH, cublas_compute, CUBLAS_GEMM_DEFAULT));
Can anyone explain?
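One way to check the layout reasoning is a small CPU sketch of a single (batch, head) slice. This is not taken from the llm.c source: the helper names `gemm_tn_colmajor`, `qk_rowmajor`, and `layout_check_max_abs_diff`, the tiny sizes, and the sample values are all made up for illustration, and the strided batching is left out. The sketch relies on the cuBLAS convention that M, N, K give the shapes of op(A), op(B), and C (M x K, K x N, M x N), not the shapes of the stored buffers. The row-major (T, HS) buffer `k` is seen by a column-major routine as an (HS, T) matrix with leading dimension HS; `CUBLAS_OP_T` turns it back into (T, HS), so M = T and K = HS. The buffer `q` is used untransposed as (HS, T), so N = T. The column-major (T, T) result K * Q^T, read back as row-major, is Q * K^T.

```c
#include <assert.h>

enum { T = 3, HS = 2 };  /* tiny sizes, chosen only for illustration */

/* CPU stand-in for one batch of a column-major GEMM with opA = T, opB = N:
 * C (m x n) = A^T * B, where A is stored as (k x m) and B as (k x n). */
void gemm_tn_colmajor(int m, int n, int k,
                      const float *A, int lda,
                      const float *B, int ldb,
                      float *C, int ldc) {
    assert(lda >= k && ldb >= k && ldc >= m);
    for (int j = 0; j < n; j++) {
        for (int i = 0; i < m; i++) {
            float acc = 0.0f;
            for (int p = 0; p < k; p++) {
                /* column-major element (p, i) of A is A[p + i * lda] */
                acc += A[p + i * lda] * B[p + j * ldb];
            }
            C[i + j * ldc] = acc;  /* column-major element (i, j) of C */
        }
    }
}

/* Reference: row-major out[t][t2] = dot(q[t, :], k[t2, :]), i.e. Q * K^T. */
void qk_rowmajor(const float *q, const float *k, float *out) {
    for (int t = 0; t < T; t++) {
        for (int t2 = 0; t2 < T; t2++) {
            float acc = 0.0f;
            for (int h = 0; h < HS; h++) {
                acc += q[t * HS + h] * k[t2 * HS + h];
            }
            out[t * T + t2] = acc;
        }
    }
}

/* Runs both paths on sample data and returns the largest absolute difference. */
float layout_check_max_abs_diff(void) {
    const float q[T * HS] = {1, 2, 3, 4, 5, 6};     /* row-major (T, HS) */
    const float k[T * HS] = {7, 8, 9, 10, 11, 12};  /* row-major (T, HS) */
    float preatt[T * T];
    float expected[T * T];

    /* Same argument pattern as the call above: (M, N, K) = (T, T, HS),
     * A = k with lda = HS, B = q with ldb = HS, C = preatt with ldc = T. */
    gemm_tn_colmajor(T, T, HS, k, HS, q, HS, preatt, T);
    qk_rowmajor(q, k, expected);

    float max_diff = 0.0f;
    for (int i = 0; i < T * T; i++) {
        float d = preatt[i] - expected[i];
        if (d < 0.0f) d = -d;
        if (d > max_diff) max_diff = d;
    }
    return max_diff;
}
```

If this reading is right, calling `layout_check_max_abs_diff()` from a `main` returns 0 for these inputs. Passing (HS, T, HS) would instead describe an (HS, T) output, which does not match the (T, T) `preatt` buffer.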