ggerganov / ggml

Tensor library for machine learning

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Conv2D kernel CuBLAS implementation - need feedback

FSSRepo opened this issue · comments

Context

In the last few days, I've been working on creating a Conv2D kernel for the "sd.cpp" project. I already have the kernel created, but when trying to implement it in "ggml," I've encountered a limitation where the data passed to the GPU must be in FP32 format, but the current CPU implementation of Conv2D requires FP16.

Here is my repo of the results: https://github.com/FSSRepo/ggml-cuda-experiments

Trying to implement the kernel in GGML CUDA:

Working in the file ggml-cuda.cu:

Add cuda kernels:

static __global__ void gemm_f16_f32(const half  *x,const half  *y, float *dst, int N, int M, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < N && col < K) {
        float sum = 0.0f;
        for (int i = 0; i < M; ++i) {
            sum += __half2float(x[row * M + i]) * __half2float(y[col * M + i]);
        }
        dst[row * K + col] = sum;
    }
}

static  __global__ void img2col_f32_f16(const float* x, half* dst, int nb12, int nb13, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
    int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
	int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
    __syncthreads();
    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
        int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW;
        int offset_src = threadIdx.x * nb13 +  blockIdx.x * nb12;
        dst[offset_dst + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z)] = __float2half(x[offset_src + iih * IW + iiw]);
    }
}

// launchers
static void img2col_f32_f16_cuda(float* x, half* dst,
    int OC, int OH,
    int IW, int IH,
    int OW, int IC,
    int KH, int KW, int N,
    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {

    int nb11 = IW;
    int nb12 = nb11 * IH; // nb[1] * ne[1]
    int nb13 = nb12 * IC; // nb[2] * ne[2]

    int CHW = IC * KH * KW;
    dim3 block_nums(IC, OH, OW);
    dim3 block_dims(N, KH, KW);
    img2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, nb12, nb13, IW, IH, CHW, s0, s1, p0, p1, d0, d1);
}

// GEMM
static void gemm_f16_f32_cuda(half* x, half* y, float* dst, int OC, int OH, int OW,int IC, int KH, int KW, int N, cudaStream_t stream) {
        int m = OC;
        int n = OH * OW;
        int k = IC * KH * KW;
        for(int i = 0; i < N; i++) {
            dim3 block_dims(16, 16);
            dim3 block_nums((n + block_dims.x - 1) / block_dims.x, (m + block_dims.y - 1) / block_dims.y);
            gemm_f16_f32<<<block_nums, block_dims, 0, stream>>>(x, y + i * m * k, dst + i * m * n, m, k, n);
        }
}

Add op functions in ggml_cuda_compute_forward:

case GGML_OP_CONV_2D_STAGE_0:
            if (!any_on_device) {
                return false;
            }
            func = ggml_cuda_conv2d_stage_0;
            break;
         case GGML_OP_CONV_2D_STAGE_1:
            if (!any_on_device) {
                return false;
            }
            func = ggml_cuda_conv2d_stage_1;
            break;

Creating ggml_cuda_conv2d_stage_0 and ggml_cuda_conv2d_stage_1 cuda functions:

void ggml_cuda_conv2d_stage_0(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16);
    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_conv2d_stage_0, false, true);
}

void ggml_cuda_conv2d_stage_1(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32);
    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_conv2d_stage_1, false, true);
}

Creating cuda ops ggml_cuda_op_conv2d_stage_0 and ggml_cuda_op_conv2d_stage_1, I need feedback in this section, what should I do?:

inline void ggml_cuda_op_conv2d_stage_0(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main) {

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t i01_diff = i01_high - i01_low;

    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];

    const int64_t N = src1->ne[3];
    const int64_t IC = src1->ne[2];
    const int64_t IH = src1->ne[1];
    const int64_t IW = src1->ne[0];

    const int64_t OC = src0->ne[3];
    // const int64_t IC = ne02;
    const int64_t KH = src0->ne[1];
    const int64_t KW = src0->ne[0];

    const int64_t OH = dst->ne[2];
    const int64_t OW = dst->ne[1];

    // NEED FEEDBACK
    // dst_ddf_i is a float 32 pointer, but the op requires float 16

    // compute
    img2col_f32_f16_cuda(src0_ddf_i, dst_ddf_i /** SHOULD BE half type ***/,
        OC, OH, IW, IH, OW, IC, KH, KW, N, s0, s1, p0, p1, d0, d1, cudaStream_main);

    (void) src1;
    (void) src0_ddq_i;
    (void) src1_ddf_i;
    (void) i1;
}

inline void ggml_cuda_op_conv2d_stage_1(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main) {

    GGML_ASSERT(src0_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t i01_diff = i01_high - i01_low;

    const int N = src1->ne[3];
    const int OH = src1->ne[2];
    const int OW = src1->ne[1];

    const int OC = src0->ne[3];
    const int IC = src0->ne[2];
    const int KH = src0->ne[1];
    const int KW = src0->ne[0];

    // NEED FEEDBACK
    // src0_ddf_i is a float 32 pointer, but the op requires float 16
   // src1_ddf_i is a float 32 pointer, but the op requires float 16

    gemm_f16_f32_cuda(
        src0_ddf_i /** SHOULD BE half type ***/, src1_ddf_i /** SHOULD BE half type ***/,
        dst_ddf_i, OC, OH, OW, IC, KH, KW, N, cudaStream_main);

    (void) src1;
    (void) src0_ddq_i;
    (void) src1_ddf_i;
    (void) i1;
}

Compiler Error:

D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\ggml-cuda.cu(5657): error : argument of type "float *" is incompatible with 
parameter of type "half *" [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
        img2col_f32_f16_cuda(src0_ddf_i, dst_ddf_i ,
                                         ^

D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\ggml-cuda.cu(5687): error : argument of type "float *" is incompatible with  
parameter of type "half *" [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
            src0_ddf_i , src1_ddf_i ,
            ^

D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\ggml-cuda.cu(5687): error : argument of type "float *" is incompatible with  
parameter of type "half *" [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
            src0_ddf_i , src1_ddf_i ,

I think you have an outdated version of ggml-cuda. ggml_cuda_op doesn't exist anymore, the equivalent is ggml_cuda_op_flatten, and the float * _dd pointers aren't necessarily float, they are just device pointers of the tensor data. For example, look at the implementation of GGML_OP_ADD. Notice that the _dd pointers are just cast to the type of the tensor:

ggml/src/ggml-cuda.cu

Lines 6916 to 6918 in 8b5c564

static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
}

ggml/src/ggml-cuda.cu

Lines 5882 to 5901 in 8b5c564

inline void ggml_cuda_op_add(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
} else {
GGML_ASSERT(false);
}
(void) src1;
(void) dst;
}

It would be good to reuse the current matrix multiplication kernels instead of adding another one.

Compiling the latest version of ggml-cublas I get this error:

D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\ggml-cuda.cu(7772): error : expected an expression [D:\proyectos\cpp-project
s\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
        *cuda_backend = (ggml_backend){
                                      ^
ggml_backend_t cuda_backend = new ggml_backend;
    *cuda_backend = (ggml_backend){
        /* .interface = */ cuda_backend_i,
        /* .context   = */ ctx
    };

In the version of the code that you mentioned

You need to update the rest of ggml as well.

You need to update the rest of ggml as well.

I had updated full ggml repository

You are using the old headers at least. These errors indicate that ggml-backend.h is not included, but it should be included in ggml-cuda.h. So you ggml-cuda.h is likely outdated.

Edit: may also be an issue with MSVC.

Does this fix the issue?

--- a/src/ggml-cuda.cu
+++ b/src/ggml-cuda.cu
@@ -7768,8 +7768,7 @@ ggml_backend_t ggml_backend_cuda_init() {

     ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda;

-    ggml_backend_t cuda_backend = new ggml_backend;
-    *cuda_backend = (ggml_backend){
+    ggml_backend_t cuda_backend = new ggml_backend {
         /* .interface = */ cuda_backend_i,
         /* .context   = */ ctx
     };

Does this fix the issue?

Yes, but I get these errors:

D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\..\include\ggml\ggml-backend.h(34,38): error C2236: unexpected token 'struct
'. check if you forget ';' [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\..\include\ggml\ggml-backend.h(34,47): error C2332: 'struct': falta el nombr
e de etiqueta [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\..\include\ggml\ggml-backend.h(34,47): error C2027: use of type'<unnamed-t
ag>' undefined [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\..\include\ggml\ggml-backend.h(45,52): error C2236: token inesperado 'struct
'. Compruebe si olvidó un ';' [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\..\include\ggml\ggml-backend.h(45,61): error C2332: 'struct': missing tag name [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\..\include\ggml\ggml-backend.h(96,31): error C2236: token inesperado 'struct
'. Compruebe si olvidó un ';' [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\..\include\ggml\ggml-backend.h(96,40): error C2332: 'struct': falta el nombr 
e de etiqueta [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]
D:\proyectos\cpp-projects\stable-diffusion.cpp\ggml\src\..\include\ggml\ggml-backend.h(96,40): error C2027: uso del tipo '<unnamed-t 
ag>' sin definir [D:\proyectos\cpp-projects\stable-diffusion.cpp\build\ggml\src\ggml.vcxproj]

Apparently, MSVC defines interface to struct. I am looking into it, but adding this to ggml-backend.h should fix it:

#ifdef interface
#undef interface
#endif

Apparently, MSVC defines interface to struct. I am looking into it, but adding this to ggml-backend.h should fix it:

#ifdef interface
#undef interface
#endif

Works!

I haven't had the chance to test if the kernel addition works because the latest version of ggml doesn't have Conv2D Stage 0 and Stage 1 implemented. Trying to reimplement everything in the latest version of ggml didn't work for me, as the program just crashes. I'll have to wait for the developer of stable-diffusion.cpp to update the ggml version, and then I can add the CUDA implementation.

For that reason, I first tried to implement it in the old version of ggml-cuda. Nevertheless, I'm not sure if enabling GGML_CUBLAS sets the CUDA backend.

How to perform the compute in GPU in the old ggml-cuda? @slaren

I'm trying this:

int KW = 3, KH = 3, IC = 640, OC = 640;
    int IW = 32, IH = 48, /* IC = 640 */ N = 1;
struct ggml_tensor* ha = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, KW, KH, IC, OC);
    memcpy(ha->data, hadata, KW * KH * IC * OC * sizeof(uint16_t));

    struct ggml_tensor* b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, IW, IH, IC, N);
    memcpy(b->data, bdata, IW * IH * IC * N * sizeof(float));

    struct ggml_tensor* result = ggml_conv_2d(ctx, ha, b, s0, s1, p0, p1, d0, d1);
    result->backend = GGML_BACKEND_GPU; //  This perform the op in gpu???

    ggml_set_name(result, "Result Tensor");
    struct ggml_cgraph gf = ggml_build_forward(result);

    ggml_graph_compute_with_ctx(ctx, &gf, 6);

    const float* ref = (float*)(result->data);

    printf("conv2d:\n%.2f %.2f %.2f %.2f\n%.2f %.2f %.2f %.2f\n",
                ref[0], ref[1], ref[2],
                ref[3], ref[4], ref[5],
                ref[6], ref[7]);

Error:

ggml_cuda_op:
ne03: 640 ne13: 1
GGML_ASSERT: D:\proyectos\cpp-projects\ggml-test\ggml\ggml-cuda.cu:5773: ne03 == ne13

The backend api is so confusing!

There are two ways to use the CUDA backend, with the old API or the ggml-backend API:

The old API is more flexible and supports things such as only offloading part of a model to VRAM, and mixing CPU and GPU computation for operations that aren't supported yet in CUDA, and partially offloaded models. I think only llama.cpp uses it fully currently. The starcoder-mmap example also supports offloading the weights to VRAM, but I think it is missing other details that will cause it to be much slower than llama.cpp. Basically, this was created to support llama.cpp, and support in other ggml projects is very poor.

ggml-backend is a recent addition that intends to provide a common API to use all the CPU and GPU backends. Currently, it only supports fully offloading all the computation to the GPU, and that requires all the weights to be stored VRAM, and all the operations must be implemented in the CUDA backend. In the future it will be extended to support partial offloading, but it is not ready yet. You can find an example of how to use it in the gpt-2 example.

If you have enough VRAM to fully offload the model, you should try ggml-backend, it will be much easier to use than the old API.

If you have more questions add a new reply, I don't get pinged when you edit your comments.