deislabs / wasi-nn-onnx

Experimental ONNX implementation for WASI NN.

[C API]: Compile ONNX bindings with GPU support

radu-matei opened this issue

This would add the CUDA and DirectML headers and pull the appropriate shared object.
See nbigaouette/onnxruntime-rs#57

As I am not currently on a CUDA-enabled machine, I'm labeling this as help wanted.

Using this branch of the ONNX Rust bindings - https://github.com/radu-matei/onnxruntime-rs/tree/cuda - the patch below works with CUDA 10.2 in the following environment:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

Patch:

From 4315f65e7fd816f0568f4f12ee58640a61e6610b Mon Sep 17 00:00:00 2001
From: Radu M <root@radu.sh>
Date: Sun, 27 Jun 2021 11:12:39 +0000
Subject: [PATCH] Trying to enable CUDA

---
 Cargo.lock                                       | 6 ++++--
 crates/wasi-nn-onnx-wasmtime/Cargo.toml          | 2 +-
 crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs | 1 +
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 5fc7a4d..cf77155 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,5 +1,7 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
+version = 3
+
 [[package]]
 name = "addr2line"
 version = "0.15.2"
@@ -1489,7 +1491,7 @@ checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56"
 [[package]]
 name = "onnxruntime"
 version = "0.0.12"
-source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=owned-session#5f47f47b24793c0d0fbb314e854cc04395b9108f"
+source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=cuda#2e5a8649def1d6cdcdd02018f9ae7c415d5f6c25"
 dependencies = [
  "lazy_static",
  "ndarray",
@@ -1501,7 +1503,7 @@ dependencies = [
 [[package]]
 name = "onnxruntime-sys"
 version = "0.0.12"
-source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=owned-session#5f47f47b24793c0d0fbb314e854cc04395b9108f"
+source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=cuda#2e5a8649def1d6cdcdd02018f9ae7c415d5f6c25"
 dependencies = [
  "flate2",
  "tar",
diff --git a/crates/wasi-nn-onnx-wasmtime/Cargo.toml b/crates/wasi-nn-onnx-wasmtime/Cargo.toml
index a307e12..8979c8d 100644
--- a/crates/wasi-nn-onnx-wasmtime/Cargo.toml
+++ b/crates/wasi-nn-onnx-wasmtime/Cargo.toml
@@ -9,7 +9,7 @@ anyhow = "1.0"
 byteorder = "1.4"
 log = { version = "0.4", default-features = false }
 ndarray = "0.15"
-onnxruntime = { git = "https://github.com/radu-matei/onnxruntime-rs", branch = "owned-session", optional = true }
+onnxruntime = { git = "https://github.com/radu-matei/onnxruntime-rs", branch = "cuda", optional = true }
 thiserror = "1.0"
 tract-data = "0.14"
 tract-linalg = "0.14"
diff --git a/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs b/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
index 1fd2d7e..71d4a16 100644
--- a/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
+++ b/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
@@ -141,6 +141,7 @@ impl WasiEphemeralNn for WasiNnOnnxCtx {
             .build()?;
         let session = environment
             .new_owned_session_builder()?
+            .use_cuda()?
             .with_optimization_level(GraphOptimizationLevel::All)?
             .with_model_from_memory(model_bytes)?;
         let session = OnnxSession::with_session(session)?;
-- 
2.17.1

Environment:

export PATH="/usr/local/cuda-10.2/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-10.2/lib64:$LD_LIBRARY_PATH"
# export LD_LIBRARY_PATH="/root/projects/onnx/onnxruntime-linux-x64-gpu-1.6.0/lib:$LD_LIBRARY_PATH"
export LD_PRELOAD=""
export ORT_USE_CUDA=1
export ORT_STRATEGY=download
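
With ORT_STRATEGY=download, the onnxruntime-sys build script downloads a prebuilt ONNX Runtime archive, and ORT_USE_CUDA=1 switches it to the GPU variant (the "onnxruntime-linux-x64-gpu-1.6.0" path in the commented-out export above). A minimal illustrative sketch of that selection logic - not the actual build.rs - assuming the hypothetical helper name `ort_package`:

```rust
/// Illustrative sketch (not the real onnxruntime-sys build.rs) of how
/// ORT_USE_CUDA selects between the CPU and GPU prebuilt packages.
fn ort_package(version: &str, use_cuda: bool) -> String {
    if use_cuda {
        // GPU variant, matching the "onnxruntime-linux-x64-gpu-1.6.0" path above.
        format!("onnxruntime-linux-x64-gpu-{version}")
    } else {
        format!("onnxruntime-linux-x64-{version}")
    }
}

fn main() {
    // ORT_USE_CUDA=1 in the exports above maps to use_cuda = true.
    let use_cuda = std::env::var("ORT_USE_CUDA").as_deref() == Ok("1");
    println!("{}", ort_package("1.6.0", use_cuda || cfg!(not(test))));
    assert_eq!(ort_package("1.6.0", true), "onnxruntime-linux-x64-gpu-1.6.0");
    assert_eq!(ort_package("1.6.0", false), "onnxruntime-linux-x64-1.6.0");
}
```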

CUDA 10.2 might be hitting this issue - microsoft/onnxruntime#5957

In any case, the performance is significantly worse than expected on a Tesla P100, and I suspect it has to do with the CUDA version.
More testing needs to be done on other hardware as well, but updating the ONNX Runtime version to one that uses CUDA 11 (1.7+) could (?) solve this issue - #25

For Windows, we should also try compiling with DirectML support - https://www.onnxruntime.ai/docs/reference/execution-providers/DirectML-ExecutionProvider.html
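
A DirectML build would presumably register its execution provider the same way `.use_cuda()?` does in the patch above. A toy sketch of such a builder, where `use_directml` is a hypothetical name and none of this is the actual onnxruntime-rs API:

```rust
// Toy builder illustrating execution-provider selection.
// `use_cuda` mirrors the call from the cuda branch; `use_directml`
// is a hypothetical Windows counterpart, not a real API.
#[derive(Debug, PartialEq)]
enum ExecutionProvider {
    Cpu,
    Cuda,
    DirectMl,
}

struct SessionBuilder {
    provider: ExecutionProvider,
}

impl SessionBuilder {
    fn new() -> Self {
        Self { provider: ExecutionProvider::Cpu }
    }
    fn use_cuda(mut self) -> Self {
        self.provider = ExecutionProvider::Cuda;
        self
    }
    fn use_directml(mut self) -> Self {
        self.provider = ExecutionProvider::DirectMl;
        self
    }
}

fn main() {
    let builder = SessionBuilder::new().use_directml();
    println!("{:?}", builder.provider);
    assert_eq!(builder.provider, ExecutionProvider::DirectMl);
    assert_eq!(SessionBuilder::new().use_cuda().provider, ExecutionProvider::Cuda);
}
```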

I've created PR nbigaouette/onnxruntime-rs#87 with CUDA 11 support for ONNX Runtime 1.7, based on nbigaouette/onnxruntime-rs#78.

I think it's what you're looking for to test. Feel free to review the branch and point out any issues :)