NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

Home Page: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html

Doesn't work on wsl2

Pzzzzz5142 opened this issue

Code:

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <iostream>
using namespace transformer_engine;

void GetSelfFusedAttnForwardWorkspaceSizes(
    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, DType dtype, bool is_training) {
  constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD;

  auto qkv_shape =
      std::vector<size_t>{batch_size * max_seqlen, 3, num_heads, head_dim};
  auto bias_shape = std::vector<size_t>{1, num_heads, max_seqlen, max_seqlen};
  for (auto i : qkv_shape)
    std::cout << i << " ";
  std::cout << std::endl;

  for (auto i : bias_shape)
    std::cout << i << " ";
  std::cout << std::endl;

  std::cout << batch_size * max_seqlen << " " << num_heads << " " << head_dim
            << std::endl;

  std::cout << batch_size << std::endl;

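  // Dummy tensors with null data pointers: for a workspace-size query only
  // the shapes and dtypes matter, so no real device buffers are allocated.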
  auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
  auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype);
  auto cu_seqlens_tensor = TensorWrapper(
      nullptr, std::vector<size_t>{batch_size + 1}, DType::kInt32);
  auto o_tensor = TensorWrapper(
      nullptr,
      std::vector<size_t>{batch_size * max_seqlen, num_heads, head_dim}, dtype);
  auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype);
  auto rng_state_tensor =
      TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);

  auto backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
      bias_type, mask_type, dropout_probability, num_heads, num_heads,
      max_seqlen, max_seqlen, head_dim);

  NVTETensorPack aux_output_tensors;
  nvte_tensor_pack_create(&aux_output_tensors);

  TensorWrapper query_workspace_tensor;
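  // With an unallocated workspace tensor, this call is expected to only fill
  // in the required workspace size rather than execute attention (mirroring
  // how TE's JAX extension probes workspace sizes).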
  nvte_fused_attn_fwd_qkvpacked(
      qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(),
      &aux_output_tensors, cu_seqlens_tensor.data(), rng_state_tensor.data(),
      max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
      bias_type, mask_type, query_workspace_tensor.data(), nullptr);
  return;
}

int main() {
  GetSelfFusedAttnForwardWorkspaceSizes(8, 64, 64, 64, 0.7, 0, NVTE_NO_BIAS,
                                        NVTE_PADDING_MASK, DType::kFloat16,
                                        false);
}

Error message:

[cudnn_frontend] INFO:  Filtering engine_configs ...1
[cudnn_frontend] INFO:  Filtered engine_configs ...1

E! CuDNN (v8907) function cudnnBackendFinalize() called:
e!         Error: CUDNN_STATUS_EXECUTION_FAILED; Reason: rtc->compile(compilerFlags, this->useNvrtcSassPath, true )
e!         Error: CUDNN_STATUS_EXECUTION_FAILED; Reason: ptr.isSupported()
e!         Error: CUDNN_STATUS_EXECUTION_FAILED; Reason: engine_post_checks(handle, *engine_iface, engine.getPerfKnobs(), req_size, engine.getTargetSMCount())
e!         Error: CUDNN_STATUS_EXECUTION_FAILED; Reason: finalize_internal()
e! Time: 2024-02-26T19:28:39.069264 (0d+0h+0m+1s since start)
e! Process=353402; Thread=353402; GPU=NULL; Handle=NULL; StreamId=NULL.

I've also tried cuDNN 9.0; however, it emits a similar error.

[cudnn_frontend] INFO:  Filtering engine_configs ...1
[cudnn_frontend] INFO:  Filtered engine_configs ...1

E! CuDNN (v90000) function cudnnBackendFinalize() called:
e!         Error: CUDNN_STATUS_EXECUTION_FAILED; Reason: rtc->loadModule()
e!         Error: CUDNN_STATUS_EXECUTION_FAILED; Reason: ptr.isSupported()
e!         Error: CUDNN_STATUS_EXECUTION_FAILED; Reason: engine_post_checks(*engine_iface, engine.getPerfKnobs(), req_size, engine.getTargetSMCount())
e!         Error: CUDNN_STATUS_EXECUTION_FAILED; Reason: finalize_internal()
e! Time: 2024-02-26T17:51:46.572696 (0d+0h+0m+0s since start)
e! Process=291453; Thread=291453; GPU=NULL; Handle=NULL; StreamId=NULL.

I've added this line to CMakeLists.txt to make it find libcuda.so properly:

target_link_directories(transformer_engine PUBLIC /usr/lib/wsl/lib)

However, it still cannot run. Is there anything I'm missing? Or is it simply non-trivial to make it work on WSL? Thanks!

I haven't tried running on WSL, although I see in this guide that there are some traps related to libcuda.so.

My hunch is that cuDNN can't find the right libcuda.so at run-time, either because it isn't looking within /usr/lib/wsl/lib or because it contains an incorrect version of the file. cuDNN JIT-compiles some kernels using NVRTC, which requires run-time access to CUDA driver functions in libcuda.so. Since libcuda.so may differ between install-time and run-time, e.g. when cross-compiling on a system with no GPUs, NVRTC programs usually have some infrastructure to dynamically find and load libcuda.so instead of relying on the linker as usual (see this comment).
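
To make that pattern concrete, here is a minimal sketch (assuming nothing about TE's or cuDNN's actual loader code) of how a program can locate the driver at run-time with dlopen/dlsym instead of linking against libcuda.so; if a check like this fails under WSL2, the run-time lookup of libcuda.so is the likely culprit:

#include <dlfcn.h>
#include <cstdio>

// Minimal sketch of the dynamic-loading pattern described above.
// Build with: g++ libcuda_check.cpp -ldl
int main() {
  // Try the usual soname first, then the bare name, as loaders typically do.
  void *handle = dlopen("libcuda.so.1", RTLD_NOW);
  if (!handle) handle = dlopen("libcuda.so", RTLD_NOW);
  if (!handle) {
    std::fprintf(stderr, "could not load libcuda: %s\n", dlerror());
    return 1;
  }
  // Resolving a driver entry point confirms the library is actually usable.
  using DriverGetVersionFn = int (*)(int *);
  auto get_version = reinterpret_cast<DriverGetVersionFn>(
      dlsym(handle, "cuDriverGetVersion"));
  int version = 0;
  if (get_version && get_version(&version) == 0)  // 0 == CUDA_SUCCESS
    std::printf("CUDA driver version: %d\n", version);
  dlclose(handle);
  return 0;
}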

Thanks for the reply. I've installed the CUDA toolkit in the recommended way, and I've verified that NVRTC works on WSL2. However, cuDNN still cannot run. Is it possible to get a more detailed failure reason, since cuDNN only emits CUDNN_STATUS_EXECUTION_FAILED? (e.g. emitting the error string of the nvrtcResult)
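
As a sanity check, something along these lines (a minimal standalone sketch, not cuDNN's actual code) exercises both the NVRTC compile step and the driver-side module load, which appear to correspond to the rtc->compile() and rtc->loadModule() calls in the logs above, and prints the real nvrtcResult/CUresult strings instead of a bare CUDNN_STATUS_EXECUTION_FAILED:

#include <cuda.h>
#include <nvrtc.h>
#include <cstdio>
#include <vector>

// Compile a trivial kernel with NVRTC, then load the resulting PTX with the
// driver API. Build with: g++ rtc_check.cpp -lnvrtc -lcuda
int main() {
  const char *src = "__global__ void k() {}";
  nvrtcProgram prog;
  nvrtcResult nres = nvrtcCreateProgram(&prog, src, "k.cu", 0, nullptr, nullptr);
  if (nres != NVRTC_SUCCESS) {
    std::fprintf(stderr, "nvrtcCreateProgram: %s\n", nvrtcGetErrorString(nres));
    return 1;
  }
  nres = nvrtcCompileProgram(prog, 0, nullptr);
  if (nres != NVRTC_SUCCESS) {
    std::fprintf(stderr, "nvrtcCompileProgram: %s\n", nvrtcGetErrorString(nres));
    return 1;
  }
  size_t ptx_size = 0;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  // Driver-side module load: this is the step that needs a working libcuda.so.
  CUdevice dev;
  CUcontext ctx;
  CUmodule mod;
  CUresult cres = cuInit(0);
  if (cres == CUDA_SUCCESS) cres = cuDeviceGet(&dev, 0);
  if (cres == CUDA_SUCCESS) cres = cuCtxCreate(&ctx, 0, dev);
  if (cres == CUDA_SUCCESS) cres = cuModuleLoadData(&mod, ptx.data());
  if (cres != CUDA_SUCCESS) {
    const char *err = nullptr;
    cuGetErrorString(cres, &err);
    std::fprintf(stderr, "driver error: %s\n", err ? err : "unknown");
    return 1;
  }
  std::puts("NVRTC compile + driver module load OK");
  return 0;
}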

Adding this line makes TE find the right libcuda.so, but it does not work for cuDNN.

target_link_directories(transformer_engine PUBLIC /usr/lib/wsl/lib)

Also, adding a RUNPATH to libtransformer_engine.so does not solve the problem:

Dynamic section at offset 0x104f2138 contains 35 entries:
  Tag        Type                         Name/Value
 0x0000000000000001 (NEEDED)             Shared library: [libcublas.so.12]
 0x0000000000000001 (NEEDED)             Shared library: [libcudart.so.12]
 0x0000000000000001 (NEEDED)             Shared library: [libnvrtc.so.12]
 0x0000000000000001 (NEEDED)             Shared library: [libnvToolsExt.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [libcudnn.so.9]
 0x0000000000000001 (NEEDED)             Shared library: [libstdc++.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [libm.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [libgcc_s.so.1]
 0x0000000000000001 (NEEDED)             Shared library: [libc.so.6]
 0x0000000000000001 (NEEDED)             Shared library: [ld-linux-x86-64.so.2]
 0x000000000000000e (SONAME)             Library soname: [libtransformer_engine.so]
 0x000000000000001d (RUNPATH)            Library runpath: [/usr/lib/wsl/lib]

I've also tried putting libcuda.so.1 and libcuda.so into /usr/local/cuda/lib64; it still doesn't work.

By the way, I have access to a native Linux machine with a 3090 and CUDA 12.1 where everything works fine, so this is not a blocking issue for me.