feat: support WASI-NN streaming extension for NNRPC
hydai opened this issue
Summary
WasmEdge forked the WASI-NN spec with the following extension for a better LLM experience.
Three new functions need to be supported:
- compute_single
- get_output_single
- fini_single
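The intended call order is easier to see in code. Below is a minimal C++ sketch of the three functions running against a toy in-memory backend. MockStreamingContext, the Status enum (including its EndOfSequence value), and every signature shown are hypothetical. They only mirror the function names above and are not WasmEdge's plugin API or the NNRPC wire format.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <utility>
#include <vector>

// Hypothetical status codes for this sketch; the real WASI-NN error numbers
// and any end-of-sequence signal used by WasmEdge may differ.
enum class Status { Ok, EndOfSequence, InvalidContext };

// Toy stand-in for an LLM backend: it "generates" a fixed list of tokens,
// one per compute_single call. Not WasmEdge's API.
class MockStreamingContext {
 public:
  explicit MockStreamingContext(std::vector<std::string> tokens)
      : tokens_(std::move(tokens)) {}

  // Advance generation by exactly one token.
  Status compute_single() {
    if (finished_) return Status::InvalidContext;
    if (next_ >= tokens_.size()) return Status::EndOfSequence;
    current_ = tokens_[next_++];
    return Status::Ok;
  }

  // Copy the latest token into a caller-owned buffer and report its length,
  // truncating if the buffer is too small.
  Status get_output_single(char* buf, std::uint32_t cap,
                           std::uint32_t* written) const {
    assert(buf != nullptr && written != nullptr);
    if (finished_) return Status::InvalidContext;
    std::size_t n = current_.size() < cap ? current_.size() : cap;
    std::memcpy(buf, current_.data(), n);
    *written = static_cast<std::uint32_t>(n);
    return Status::Ok;
  }

  // Release per-stream state; every later call reports InvalidContext.
  Status fini_single() {
    if (finished_) return Status::InvalidContext;
    finished_ = true;
    current_.clear();
    return Status::Ok;
  }

 private:
  std::vector<std::string> tokens_;
  std::size_t next_ = 0;
  std::string current_;
  bool finished_ = false;
};

// Client loop: compute one token, fetch it, repeat; then tear the stream down.
std::string stream_all(MockStreamingContext& ctx) {
  std::string out;
  char buf[64];
  std::uint32_t n = 0;
  while (ctx.compute_single() == Status::Ok) {
    ctx.get_output_single(buf, static_cast<std::uint32_t>(sizeof(buf)), &n);
    out.append(buf, n);  // a real client could print or forward each token here
  }
  ctx.fini_single();
  return out;
}
```

For example, with MockStreamingContext ctx({"Hello", ",", " world"}), stream_all(ctx) returns "Hello, world", and any call made after fini_single reports InvalidContext.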
Details
Make NNRPC support the streaming LLM mode.
- Update the wasi-nn proto file: https://github.com/WasmEdge/WasmEdge/blob/master/lib/wasi_nn_rpc/wasi_ephemeral_nn.proto
- Update the NNRPC implementation
Appendix
Hi @hydai, I'm a student at the University of Texas at Austin. I was interested in working on this issue as a part of a project in my virtualization class (CS 360V). I've got the necessary background on LLMs but am new to WasmEdge – do you think it would make sense for me to work on this? Thanks!
@arvganesh Please have a look at Akihiro Suda's talk on NNRPC here (from 2:08)
https://youtu.be/D0D8ufWtILI?si=nTF5nklApkmO7cCu&t=128
If you can build and run NNRPC from source, then you are ready for this. :)
Would you be able to assign this to me? Thanks!
Done. Look forward to your contributions!
Hi @juntao, @hydai! I wanted to outline my understanding of this issue and of WasmEdge. Let me know if this makes sense before I begin working on it:
Background:
- WasmEdge is a WebAssembly runtime
- Wasi-NN defines a set of APIs and abstractions (Tensors, Backends, Graphs) enabling neural network inference.
- WASI = WebAssembly System Interface, which provides low-level system interfaces for WebAssembly applications.
- Wasi-NN provides an interface to existing NN backends (TensorFlow, OpenVINO, etc.) for inference.
- https://github.com/second-state/wasmedge-wasi-nn/tree/ggml – this project contains Rust bindings implementing the Wasi-NN spec.
- Wasi-NN RPC allows WASM applications (clients) to make RPC calls to a server, which performs model inference and returns the result. The benefit is that the compute accelerators (GPUs, etc.) don't need to be on the client machine.
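A minimal sketch of that split, assuming hypothetical InferenceBackend, LocalBackend and RemoteBackend types: the guest-facing call looks the same either way, and only the remote variant pushes the request through a transport to whichever host owns the accelerator. Real NNRPC uses gRPC for the transport, which is not modeled here.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

// The interface the guest-facing code talks to; it does not care where
// inference actually runs. Names are hypothetical, not WasmEdge's.
struct InferenceBackend {
  virtual ~InferenceBackend() = default;
  virtual std::string infer(const std::string& prompt) = 0;
};

// Runs the "model" in-process, as a local CPU/GPU backend would.
struct LocalBackend : InferenceBackend {
  std::string infer(const std::string& prompt) override {
    return "echo:" + prompt;
  }
};

// Transport: request bytes in, response bytes out. In NNRPC this role is
// played by gRPC; here it is just a function.
using Transport = std::function<std::string(const std::string&)>;

// Forwards every call through the transport to a remote server.
struct RemoteBackend : InferenceBackend {
  explicit RemoteBackend(Transport t) : send_(std::move(t)) { assert(send_); }
  std::string infer(const std::string& prompt) override {
    return send_(prompt);
  }

 private:
  Transport send_;
};
```

The design point is that code written against InferenceBackend does not change when the model moves to a machine with a GPU; only the object behind the interface does.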
What I need to do:
- Something similar to 9c60444 but for the three functions listed.
- This will enable "streaming mode" for LLMs, meaning during inference, output tokens can be computed and returned one at a time.
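In practice the difference shows up as time to first output. The sketch below uses a hypothetical ToyGenerator that counts compute steps. It compares a batch call, where the caller sees nothing until all N tokens exist, with a streaming loop that hands each token to a callback as soon as it is computed.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Toy generator: each step() "computes" one more token and counts the work.
// A hypothetical stand-in for a backend, not WasmEdge code.
struct ToyGenerator {
  std::vector<std::string> tokens;
  std::size_t pos = 0;
  int steps = 0;  // compute steps performed so far

  bool step(std::string& tok) {
    if (pos >= tokens.size()) return false;
    ++steps;
    tok = tokens[pos++];
    return true;
  }
};

// Non-streaming: the caller sees output only after every token is computed.
// Returns how many compute steps ran before any output became visible.
int batch_first_visible(ToyGenerator g, std::string& out) {
  std::string tok;
  while (g.step(tok)) out += tok;
  return g.steps;
}

// Streaming: each token goes to on_token as soon as it is computed.
// Returns how many compute steps ran before the first token was delivered
// (-1 if nothing was generated).
int stream_first_visible(ToyGenerator g,
                         const std::function<void(const std::string&)>& on_token) {
  assert(on_token);
  std::string tok;
  int first = -1;
  while (g.step(tok)) {
    if (first < 0) first = g.steps;
    on_token(tok);
  }
  return first;
}
```

Both modes produce the same text; with four tokens, the batch caller waits four steps while the streaming caller gets its first token after one.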
Questions I have:
- What is the purpose of fini_single? Why is there no analogous function without the "_single" suffix?
- What's the purpose of HostFuncCaller in https://github.com/WasmEdge/WasmEdge/blob/master/include/driver/wasi_nn_rpc/wasi_nn_rpcserver/wasi_nn_rpcserver.h#L121? My understanding is that it's a backend-agnostic wrapper that calls the intended function for a specific backend and gets the result for gRPC to send back. Is that correct?
Thanks for your help! I appreciate it.
- What is the purpose of fini_single? Why is there no analogous function without the "_single" suffix?
The _single functions have their own lifecycle: the context is destroyed only when fini_single is called. The normal functions destroy the context automatically.
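One way to picture this, assuming a hypothetical ToyBackend that keys streaming state by a context id: compute keeps its state in a local variable that is destroyed when the call returns, while compute_single leaves its state in a table until fini_single erases it. This illustrates the answer above and is not WasmEdge's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical per-stream state, e.g. how far generation has progressed.
struct StreamState {
  std::size_t pos = 0;
};

// Toy backend contrasting the two lifecycles. Not WasmEdge's implementation.
class ToyBackend {
 public:
  explicit ToyBackend(std::vector<std::string> tokens)
      : tokens_(std::move(tokens)) {}

  // Normal mode: the state exists only for this call and is destroyed
  // automatically when `local` goes out of scope.
  std::string compute() const {
    StreamState local;
    std::string out;
    while (local.pos < tokens_.size()) out += tokens_[local.pos++];
    return out;
  }

  // Streaming mode: state is kept between calls, keyed by a context id.
  // Returns false once the sequence is exhausted; the state is NOT freed then.
  bool compute_single(std::uint32_t ctx_id, std::string& token) {
    StreamState& st = streams_[ctx_id];  // created on first use
    if (st.pos >= tokens_.size()) return false;
    token = tokens_[st.pos++];
    return true;
  }

  // Only an explicit fini_single releases the streaming state.
  void fini_single(std::uint32_t ctx_id) { streams_.erase(ctx_id); }

  std::size_t live_streams() const { return streams_.size(); }

 private:
  std::vector<std::string> tokens_;
  std::map<std::uint32_t, StreamState> streams_;
};
```

Note that the streaming state survives even after the sequence ends. This is why a separate fini_single is needed, while the non-streaming path has nothing left to clean up.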
- What's the purpose of HostFuncCaller in https://github.com/WasmEdge/WasmEdge/blob/master/include/driver/wasi_nn_rpc/wasi_nn_rpcserver/wasi_nn_rpcserver.h#L121? My understanding is that it's a backend-agnostic wrapper that calls the intended function for a specific backend and gets the result for gRPC to send back. Is that correct?
It allows you to call the host function directly, so this simplified interface gives you an easier way to interact with the host function.
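As a rough illustration of the pattern, and not the actual HostFuncCaller interface (whose signatures may differ), a caller can keep host functions in a table keyed by name and invoke them directly with Wasm-style i32 arguments, without instantiating a guest module. ToyHostFuncCaller and HostFn are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// In this sketch a host function takes i32 arguments (as a Wasm guest would
// pass them) and returns an i32 error code. These signatures are illustrative;
// they are not WasmEdge's HostFuncCaller interface.
using HostFn = std::function<std::int32_t(const std::vector<std::uint32_t>&)>;

class ToyHostFuncCaller {
 public:
  void add(const std::string& name, HostFn fn) {
    assert(fn);
    fns_[name] = std::move(fn);
  }

  // Invoke a registered host function by name, with no Wasm module involved.
  std::int32_t call(const std::string& name,
                    const std::vector<std::uint32_t>& args) const {
    auto it = fns_.find(name);
    if (it == fns_.end()) throw std::out_of_range("no host function: " + name);
    return it->second(args);
  }

 private:
  std::map<std::string, HostFn> fns_;
};
```

In an RPC server, a request handler could use something like this to forward each incoming call straight to the corresponding host function and send the returned error code back.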
Hi @hydai! I've opened a draft PR adding the features discussed, and it keeps failing this build check. Unfortunately, the only error message is C++ exception with description "std::bad_cast" thrown in the test body., without any line numbers. I can't figure out the right set of commands to run these tests locally, so my ability to iterate quickly is pretty low ATM.
These are the commands I'm using to build and run tests on macOS. They run tests, just not the ones for WasiNN.
cmake -Bbuild -GNinja -DWASMEDGE_BUILD_WASI_NN_RPC=On -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_TESTS=ON
DYLD_LIBRARY_PATH=$(pwd)/lib/api ctest
Do you have any thoughts on what I'm doing wrong here?
Please run the same commands that the CI runs to reproduce the issue.
It seems like the tests you added are triggering this exception.
There is an error message:
unknown file: Failure
C++ exception with description "std::bad_cast" thrown in the test body.
You should check the related files for any casting issues.
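One common way to get exactly this message is a dynamic_cast to a reference type that fails at run time. The reference form throws std::bad_cast, and GoogleTest reports an exception escaping the test body as "unknown file: Failure" with no line number, as in the log above. The pointer form returns nullptr instead, which lets a test check the result and report it explicitly. Base, Expected and Actual are hypothetical types. The failing cast in the PR could be elsewhere, for example in code that the new tests call.

```cpp
#include <cassert>
#include <typeinfo>

// Hypothetical hierarchy standing in for whatever polymorphic type is cast.
struct Base { virtual ~Base() = default; };
struct Expected : Base {};
struct Actual : Base {};

// Reference form: a failed dynamic_cast throws std::bad_cast.
bool reference_cast_throws(Base& b) {
  try {
    (void)dynamic_cast<Expected&>(b);
    return false;
  } catch (const std::bad_cast&) {
    return true;
  }
}

// Pointer form: a failed dynamic_cast yields nullptr, so the caller can
// check it (e.g. with ASSERT_NE in a test) and get a line number on failure.
bool pointer_cast_fails(Base* b) {
  return dynamic_cast<Expected*>(b) == nullptr;
}
```

Switching a suspect reference cast to the pointer form, or wrapping the call in EXPECT_NO_THROW, is one way to turn the anonymous exception into a failure that points at a specific line.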