RWKV-4 169m/430m in browser with ORT Web / TF.js / tfjs-tflite?
josephrocca opened this issue
Hi, really exciting project! I'm wondering if you've published the model conversion script that you used to create the js_models files from the .pth
model file? It would be awesome to see how the larger and newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.
Exporting to ONNX is something that I've been tinkering with and I can report that the 169m RWKV-4 model does run in browser. Here's my code: https://github.com/AXKuhta/RWKV-LM/tree/onnx
There are two things missing:
- JavaScript implementation of the tokenizer
- JavaScript implementation of sample_logits().
Running python -i -u export_onnx.py and then rnn_export() will export the model as rwkw.onnx, which can then be tested with test_onnx.py and loaded from index.html. The demo in index.html uses greedy sampling, and you just sorta have to visit https://goose.ai/tokenizer in order to encode/decode the text. It works using the wasm backend, but unfortunately throws an error if you try the webgl backend.
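Roughly, the loading + greedy sampling part looks something like this (a simplified sketch, not the exact demo code):
// Create an onnxruntime-web session with the wasm backend
// (the ort global comes from the onnxruntime-web script).
const session = await ort.InferenceSession.create('rwkw.onnx', {
  executionProviders: ['wasm'],
});

// Greedy sampling = argmax over the logits.
function greedySampling(logits) {
  let best = 0;
  for (let i = 1; i < logits.length; i++) {
    if (logits[i] > logits[best]) best = i;
  }
  return best;
}

// const results = await session.run(feeds);       // feeds: token + recurrent state tensors
// const token = greedySampling(results.x.data);   // results.x holds the logits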
Exporting to ONNX is something that I've been tinkering with and I can report that the 169m RWKV-4 model does run in browser. Here's my code: https://github.com/AXKuhta/RWKV-LM/tree/onnx
Great work :)
Did you get this error with webgl? cannot resolve operator 'Max' with opsets: ai.onnx v13
You can remove RWKV_HEAD_QK and RWKV-ffnPre, which are not required for Pile models, and that will probably fix it.
P.S. Upgrade onnxruntime to the latest version and then you can test CUDAExecutionProvider in Python. I think you might be using an older onnxruntime, because all new versions require explicitly setting providers when initializing InferenceSession().
@AXKuhta Nice! I got a web demo going here (for 169m and 430m):
- Demo: https://josephrocca.github.io/rwkv-v4-web/demo/
- Code: https://github.com/josephrocca/rwkv-v4-web
But it seems like something is going wrong - the model isn't "coherent" in using the context. For example, if you prompt the 430m model with "The capital of France is" it continues with "first of the, the city of Paris". I checked that the tokenizer is working properly, so I think it's something to do with the inference / context-handling code.
Some other random notes:
- The models doubled in size when ported to ONNX - e.g. the 169m model goes from 339MB to 679MB. I quantized it down to 171MB, but that makes inference half the speed (~5 tokens/sec for quantized versus ~13 tokens/sec for original). I'm guessing the non-quantized version has been converted from f16 to f32, hence the size doubling? The demo includes both the normal and quantized versions.
- @BlinkDL Yes, I got TypeError: cannot resolve operator 'Max' with opsets: ai.onnx v13 when trying to use the WebGL backend. How would I go about removing RWKV_HEAD_QK and RWKV-ffnPre? I made a conversion notebook here: https://colab.research.google.com/github/josephrocca/rwkv-v4-web/blob/main/RWKV_v4_ONNX_conversion.ipynb Is it as simple as adding a few commands to that, or is there more work involved?
- The WebGL backend doesn't work with quantized models. It gives this error: TypeError: cannot resolve operator 'DequantizeLinear' with opsets: ai.onnx v13, com.microsoft.experimental v1, ai.onnx.preview.training v1, ai.onnx.training v1, com.ms.internal.nhwc v17, org.pytorch.aten v1, com.microsoft.nchwc v1, ai.onnx.ml v3, com.microsoft v1
- I used a very overkill approach to getting the tokenizer working... https://github.com/josephrocca/tokenizers-pyodide I haven't looked into how different the tokenizer is from GPT-2/3's, but if it's similar, then I guess it shouldn't be too hard to make an edited version of this: https://github.com/josephrocca/gpt-2-3-tokenizer ?
For example, if you prompt the 430m model with "The capital of France is" it continues with "first of the, the city of Paris"
That seems familiar!
The => first
The capital => of
The capital of => the
The capital of France => ,
The capital of France is => the
It looks like you display the outputs during the prompt-feeding stage, which happens one token at a time.
That should fix it:
  let token = greedySampling(results.x.data);
  if (promptTokens.length == 0) {
+   if(streamingCallback) streamingCallback(token);
    ctx.push( token );
  } else {
    ctx.push( promptTokens.shift() );
  }
-
- if(streamingCallback) streamingCallback(token);
  feeds.xx_att = results.xx_att_r;
  feeds.aa_att = results.aa_att_r;
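The same pattern applies to the rest of the recurrent state - a sketch, assuming every state output is named like its input plus an _r suffix (as with xx_att/xx_att_r and aa_att/aa_att_r above):
// Copy every state output back into the matching input for the next step.
for (const name of Object.keys(feeds)) {
  if (results[name + '_r'] !== undefined) {
    feeds[name] = results[name + '_r'];
  }
}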
@josephrocca I had to host the demo locally because huggingface keeps terminating the model downloads for some reason, but otherwise I can confirm that it works on my machine. Good job with getting the tokenizer and the quantization working!
I'm guessing the non-quantized version has been converted from f16 to f32, hence the size doubling?
Yeah, that's what's happening. RWKV-v4 is bf16 which can't be losslessly converted to fp16, so fp32 is the next best option. The fp32-converted model also compresses really well since half the bytes in it are zero.
@BlinkDL Yes, I got TypeError: cannot resolve operator 'Max' with opsets: ai.onnx v13 when trying to use the WebGL backend. How would I go about removing RWKV_HEAD_QK and RWKV-ffnPre? I made a conversion notebook here: https://colab.research.google.com/github/josephrocca/rwkv-v4-web/blob/main/RWKV_v4_ONNX_conversion.ipynb Is it as simple as adding a few commands to that, or is there more work involved?
Take a look at src/model_run.py.
For the Pile models, self.model_type == 'RWKV' and RWKV_HEAD_QK_DIM = 0, so you can remove some unused code.
Moreover, use https://github.com/daquexian/onnx-simplifier to optimize the ONNX model.
And the ONNX version might work on AMD & Intel GPUs. The DirectML backend supports them (on Win10).
I tried that with RWKV-1.
Yeah, that's what's happening. RWKV-v4 is bf16 which can't be losslessly converted to fp16, so fp32 is the next best option. The fp32-converted model also compresses really well since half the bytes in it are zero.
You can losslessly "transform" bf16 to fp16: the idea is to keep the same raw binary value. The float values will be totally different, but you can do an inverse transform in JS to losslessly recover the original bf16.
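The recovery step could look something like this sketch (assuming the raw 16-bit words are available in JS, e.g. read from a little-endian .bin file): since bf16 is just the top half of an fp32, shifting each word left by 16 bits reconstructs the exact original value.
// Losslessly widen raw bf16 words to fp32.
function bf16ToFloat32(bytes) {          // bytes: Uint8Array of raw bf16 data
  const words = new Uint16Array(bytes.buffer, bytes.byteOffset, bytes.byteLength / 2);
  const bits = new Uint32Array(words.length);
  for (let i = 0; i < words.length; i++) {
    bits[i] = words[i] << 16;            // bf16 occupies the upper 16 bits of fp32
  }
  return new Float32Array(bits.buffer);
}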
@AXKuhta Could have sworn I replied here earlier, sorry - apparently I didn't click send. I fixed the demo according to your comment soon after you posted it - thanks for your help!! Strange that huggingface is terminating the download for you... 🤔
@BlinkDL Thanks for the tips! I'll look into the stuff you've mentioned.
Hi, really exciting project! I'm wondering if you've published the model conversion script that you used to create the js_models files from the .pth model file? It would be awesome to see how the larger and newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.
The Python code for RWKV-2 weight conversion to .bin (for tf.js):
import torch

# MODEL_NAME is the checkpoint path, without the .pth extension
w = torch.load(MODEL_NAME + '.pth')

for x in w.keys():
    if 'copy_mask' in x:  # this is for headQK, which is not used in Pile models
        continue
    print(x, w[x].shape)

    # we are doing some pre-computations here. change them to match RWKV-4,
    # or you can just skip all of them and do everything in js first.
    if '.time_' in x:
        w[x] = w[x].squeeze()
    if '.time_decay' in x:
        w[x] = torch.exp(-torch.exp(w[x]))
    if '.time_first' in x:
        w[x] = torch.exp(w[x])

    w[x].numpy().tofile(f'20220425/{x}.bin')
You can gradually port it to RWKV-4 by matching the outputs for each layer.
The Chinese RWKV-2 has a better UI: https://github.com/BlinkDL/AI-Writer/blob/main/docs/index.html
The English RWKV-2: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng
@AXKuhta Could have sworn I replied here earlier, sorry - apparently I didn't click send. I fixed the demo according to your comment soon after you posted it
Add top-p, top-k, and temperature sampling, and then it's very usable :)
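Something along these lines for a JS sample_logits() (a sketch; the parameter defaults and the exact order of operations are placeholders, not the Python repo's implementation):
function sampleLogits(logits, temperature = 1.0, topP = 0.85, topK = 0) {
  const n = logits.length;
  // Softmax with temperature
  let max = -Infinity;
  for (let i = 0; i < n; i++) if (logits[i] > max) max = logits[i];
  const probs = new Float64Array(n);
  let sum = 0;
  for (let i = 0; i < n; i++) {
    probs[i] = Math.exp((logits[i] - max) / temperature);
    sum += probs[i];
  }
  for (let i = 0; i < n; i++) probs[i] /= sum;
  // Sort token ids by probability, descending, and apply the top-k cutoff
  let ids = Array.from(probs.keys()).sort((a, b) => probs[b] - probs[a]);
  if (topK > 0) ids = ids.slice(0, topK);
  // Keep the smallest prefix whose cumulative probability reaches topP
  let cum = 0, cut = ids.length;
  for (let i = 0; i < ids.length; i++) {
    cum += probs[ids[i]];
    if (cum >= topP) { cut = i + 1; break; }
  }
  // Sample from the truncated distribution
  let r = Math.random() * cum;
  for (let i = 0; i < cut; i++) {
    r -= probs[ids[i]];
    if (r <= 0) return ids[i];
  }
  return ids[cut - 1];
}
The demo's greedySampling(results.x.data) call could then be swapped for something like sampleLogits(results.x.data, 1.0, 0.85).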
It looks like the webgl backend has a lot of limitations. I did some testing by stripping out different parts of the model in order to see if I can get anything at all to work on the webgl backend. I think I got like four different error messages with different combinations. The bottom line is that I can't even get a matmul to work.
https://github.com/AXKuhta/RWKV-LM/blob/matmul/RWKV-v4/matmul_test.py
https://github.com/AXKuhta/RWKV-LM/blob/matmul/RWKV-v4/index.html
It does work on the wasm backend!
EDIT: It actually works on webgl if you do this: AXKuhta@75ad160
I have been able to force the full model to run on webgl, but it doesn't produce anything coherent, so something's still broken:
https://github.com/AXKuhta/RWKV-LM/tree/onnx_webgl
@BlinkDL The "cannot resolve operator 'Max' with opsets: ai.onnx v13" error was caused by torch.maximum(pp, ww)
and I was able to suppress it by using torch.max(torch.stack([pp, ww]), 0).values
instead. I also had to add a bunch of .view([768,1]) around matmul operations and then fix layer_norm() from producing NaNs. Now it looks like self.FF() always produces zeroes on webgl, but I'm not sure yet as to why nope, self.FF() does produce something.
I have been able to force the full model to run on webgl, but it doesn't produce anything coherent, so something's still broken
That's great. Could you check whether https://github.com/daquexian/onnx-simplifier can help? Use https://github.com/lutzroeder/netron to visualize models.
And then you can print() the outputs of interesting layers to find the culprit... gradually matching the results of webgl vs wasm.
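One way to do that from the JS side is to run the same feeds through a wasm session and a webgl session and diff whatever tensors are exposed as graph outputs (intermediate layers would have to be added as extra outputs first) - a rough sketch, with modelUrl and feeds as placeholders:
const wasmSession = await ort.InferenceSession.create(modelUrl, { executionProviders: ['wasm'] });
const webglSession = await ort.InferenceSession.create(modelUrl, { executionProviders: ['webgl'] });

const a = await wasmSession.run(feeds);
const b = await webglSession.run(feeds);

// Report the largest element-wise difference per output tensor.
for (const name of Object.keys(a)) {
  let maxDiff = 0;
  for (let i = 0; i < a[name].data.length; i++) {
    maxDiff = Math.max(maxDiff, Math.abs(a[name].data[i] - b[name].data[i]));
  }
  console.log(name, 'max abs diff:', maxDiff);
}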
@BlinkDL After some painstaking debugging I got it to produce coherent output on webgl. The fix was really bizarre: adding + 0.0 in a bunch of places. Some nodes in the ONNX graph that follow matmul+reshape operations kept getting bugged inputs that looked like a single value repeated across all 768 elements. Performing + 0.0 on the bugged input fixes it.
Here are the changes: https://github.com/AXKuhta/RWKV-LM/commits/onnx_webgl
Could you check whether https://github.com/daquexian/onnx-simplifier can help?
I did try onnx-simplifier with RWKV-3, but it didn't find much to simplify. The graph was almost unchanged. I will retest with RWKV-4 though.
@AXKuhta Nice! Can you upload the webgl-compatible 169m/430m models to hugging face so I can add them to the web demo?
Also, I wonder if the +0.0 bug is something that would be worth reporting to the ONNX runtime team?
@josephrocca I think it's better to keep all the web models in one place so I made two PRs in your huggingface repository. Oh, and by the way, I also improved my initial index.html a little to not create new tensors inside the loop and to remove leading_pad(). I think you should integrate these changes into your demo too.
I ran some performance tests with the hardware that I have available:
All tests performed in Chromium
169m model
========= WASM =========
Intel Core i7 2760QM: 280ms per token
Intel Core i7 6650U: 204ms per token
AMD A10-7800: 331ms per token
Snapdragon 865: 233ms per token
========= WebGL =========
Intel Core i7 2760QM iGPU: 600ms per token
Nvidia GeForce 520MX: 305ms per token
Intel Core i7 6650U iGPU: 192ms per token
AMD A10-7800 iGPU: 232ms per token
Snapdragon 865 iGPU: Produces NaNs
These numbers are not very impressive 😹
I didn't try it on a real GPU with a wide memory bus, but I suspect it won't perform massively better.
There are three different webgl bug reports to be made to onnxruntime:
- Matmuls like [768, 768] @ [768] complain about dimension mismatch, must be converted to [768, 768] @ [768, 1]
- NaNs produced by layer_norm() if there are negative inputs
- This strange +0.0 stuff if I can reproduce it in a standalone fashion
@AXKuhta Maybe there are some hidden bottlenecks :) Check the time consumption of all major functions and code fragments.
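e.g. with a tiny performance.now() helper around the interesting fragments (a sketch):
// Accumulate wall-clock time per label across a generation run.
const timings = {};
async function timed(label, fn) {
  const t0 = performance.now();
  const out = await fn();
  timings[label] = (timings[label] || 0) + (performance.now() - t0);
  return out;
}

// const results = await timed('session.run', () => session.run(feeds));
// console.log(timings);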
@AXKuhta Thanks! Great work. I've always struggled with the WebGL backend - I'm guessing that it doesn't get as much attention as wasm because it isn't a port of C++, but must be written from scratch IIUC. I'm hoping that WebGPU will change that situation and we'll get really serious GPU ML on the web.
Another factor that could be relevant here, re: performance, is that wasm can just be faster for some models, though I'd have thought that this would only be the case for models that are very small. There's some discussion of this in this article about tf.js: https://blog.tensorflow.org/2020/09/supercharging-tensorflowjs-webassembly.html
@BlinkDL The final [768, 50277] matmul is the slowest component. It's almost as slow as the entire model on WASM, which is kind of surprising, considering that GPUs are supposed to be good at matmul. It may be caused by the fact that it doesn't fit under the texture size limit of 16384 so onnxruntime does some magic to remap it into a 6214x6214 texture instead, possibly making it slow.
Gradually removing parts of the model until there is nothing left except input->output passthrough (Nvidia GeForce 520MX, 169m model):
Baseline full model: 344ms (N/A)
Removed state store/restore: 326ms (-18ms)
Removed final matmul: 145ms (-181ms)
Removed 12 x self.FF(): 60ms (-85ms)
Removed 12 x self.SA(): 30ms (-30ms)
Removed 26 x self.LN(): 16ms (-14ms)
Removed w.emb.weight[ctx[-1]]: 0.7ms (-15.3ms)
@josephrocca Yeah, I think it's better to wait for WebGPU instead of pursuing WebGL any further. It seems to work well for graphics, but not so much for compute.
@BlinkDL The final [768, 50277] matmul is the slowest component. It's almost as slow as the entire model on WASM, which is kind of surprising, considering that GPUs are supposed to be good at matmul. It may be caused by the fact that it doesn't fit under the texture size limit of 16384 so onnxruntime does some magic to remap it into a 6214x6214 texture instead, possibly making it slow.
You could probably try tf.js for the final matmul and see if it's faster.
@AXKuhta @josephrocca And actually you can skip the final matmul when scanning the prompt (because we just need the hidden states).
I will provide some more efficient code soon to quickly generate the initial hidden states from prompt.
Oh and please check the speed of onnxruntime in pytorch :) I wonder if it will be faster.
You can actually install pytorch in Android too.
And actually you can skip the final matmul when scanning the prompt
@BlinkDL Ooh, somehow I didn't think of that before!
There is a "only_execute_path_to_fetches" switch in onnxruntime that can be used to make this work even with existing .onnx files. It looks like they forgot to expose it to JavaScript, so I had to make a custom build of ort-wasm-simd.wasm with that flag toggled in the source. I found that it actually works:
Intel Core i7 2760QM
169m model
WASM only_execute_path_to_fetches = true
Don't want the x output 158ms per token
Want the x output 258ms per token
I put the custom-built ort-wasm-simd.wasm and the index.html updated with fetches logic here if anyone wants to try this too.
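The JS side of the fetches logic is basically just passing a list of output names to run() - a sketch using the output names from earlier in the thread (and note the pruning only actually kicks in with the custom ort-wasm-simd.wasm build mentioned above):
// While scanning the prompt, request only the recurrent state outputs and not
// 'x' (the logits), so the big [768, 50277] head matmul can be skipped.
const stateOutputs = ['xx_att_r', 'aa_att_r' /* , ...the rest of the state outputs */];
const results = await session.run(feeds, stateOutputs);

// Once the prompt has been consumed, request everything (including 'x') again:
// const results = await session.run(feeds);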
I think it should be possible to pack both the RNN-style model and the GPT-style model into a single .onnx graph. Since the weights are shared between the two, there would only be a minimal increase in file size. I'll wait for the new GPT code (The current one doesn't run without CUDA btw).
Oh and please check the speed of onnxruntime in pytorch
Here are some performance numbers for RWKV-4 with PyTorch and native onnxruntime:
Native PyTorch + onnxruntime, 169m model:
Intel Core i7 2760QM: PyTorch 79.3 ms/token
Intel Core i7 2760QM: ONNX 152 ms/token (Note: ONNX forced to use 8 threads to hit full CPU utilization)
Intel Core i7 6650U: PyTorch 62.1 ms/token
Intel Core i7 6650U: ONNX 129 ms/token (Note: ONNX forced to use 4 threads to hit full CPU utilization)
Snapdragon 865: PyTorch 71.0 ms/token
Snapdragon 865: ONNX 180 ms/token
But I think I made a bit of a mistake by not excluding sample_logits() from the PyTorch version. It seems to take about ~10ms too. I need to rerun those tests more carefully.
EDIT: I totally forgot that my test_onnx.py had sample_logits() too, so these comparisons are fair after all.
Finally tested the webgl backend on a real GPU:
GTX 1060 6GB, webgl:
169m model: 68.6 ms/token
430m model: 119 ms/token
As seen above, the 430m model also works on webgl now. It turns out my state store/restore code was breaking it: with a 24-layer model, it would attempt to stack 24 tensors at once, which would exceed the 16 input textures limit in WebGL. I worked around this by stacking 12 tensors at a time, twice, then using torch.cat() to glue the two stacks together.
The stacking code can be removed, but then the 430m model would have 120 individual inputs/outputs for state, which sounds scary.
I guess this kind of vindicates the webgl backend? It does outperform wasm when used on a real GPU, and it can also run the non-quantized 430m model, while wasm can't. Of course, it is still significantly slower than native.
@josephrocca I opened two new PRs in your huggingface repo, one with the updated 430m webgl model and the other removing the outdated model.
@AXKuhta Thanks! I've accepted the pull request and updated the demo.
it can also run the non-quantized 430m model, while wasm can't
Note that the wasm runtime would be able to run the non-quantized, 1.7GB model with no problems if it had enough memory available. There's currently an arbitrary 2GB limit that needs to be raised: microsoft/onnxruntime#10957 (comment)
The memory limits should be gone completely once we get Memory64: https://github.com/WebAssembly/memory64
The memory limits should be gone completely once we get Memory64
So there is work ongoing to lift that limit. That's good to know 👍
@AXKuhta Thanks! I've accepted the pull request and updated the demo.
Please try the raw binary BF16 trick too :) #7 (comment)
And please show the progress (1/32 etc.) in the webpage
Another idea: w.emb.weight should just be a simple Float32Array on the CPU.
@BlinkDL That's totally doable.
Placing it into a separate file seems like a reasonable way to accomplish this.
It may fix the NaN problem with webgl on Snapdragon iGPUs, which happened exactly in w.emb.weight[ctx[-1]]. I think it was caused by a 4096x4096 texture size limit in Adreno GL ES drivers, unlike 16384x16384 on AMD/Nvidia/Intel. A 50277x768 tensor represented as a 6214x6214 texture thus fails to fit on Adreno. The final matmul is probably broken on Adreno too because of this.
Please try the raw binary BF16 trick too :) #7 (comment)
I don't quite understand the idea here. Do you mean storing bf16 weights in files and then converting them to fp32 or fp16 at runtime?
Placing it into a separate file seems like a reasonable way to accomplish this.
Yeah, remove it from the ONNX model. The model will directly use the embedded vector as input. This saves VRAM and should be much faster.
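A sketch of what that could look like in JS (the file name emb.weight.bin, the input name emb, and n_embd = 768 for the 169m model are placeholders):
// Keep w.emb.weight as a plain Float32Array on the CPU and feed the looked-up
// row to the model instead of the token id.
const embWeight = new Float32Array(await (await fetch('emb.weight.bin')).arrayBuffer());
const nEmbd = 768;

function embed(tokenId) {
  const row = embWeight.slice(tokenId * nEmbd, (tokenId + 1) * nEmbd);
  return new ort.Tensor('float32', row, [nEmbd]);
}

// feeds.emb = embed(ctx[ctx.length - 1]);   // replaces the in-graph embedding lookup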
A 50277x768 tensor represented as a 6214x6214 texture thus fails to fit on Adreno
Unfortunately you will still have this problem when doing the head (output) matmul... But I think you can split 50277 into chunks.
Do you mean storing bf16 weights in files and then converting them to fp32 or fp16 at runtime?
Yeah, store the bf16 weights as 16-bit binary files, then decode them at runtime in JS when loading the model.
See https://github.com/BlinkDL/AI-Writer/blob/main/docs/eng/index.html#L231 for an example of loading binary weights.