torch compile support?
stillmatic opened this issue
System Info
colab
torch: 2.2.1+cu121
bnb: 0.43.1
Reproduction
https://colab.research.google.com/drive/1SYd8E0SmELbYiVEwfPezReeI0ToNVYCk?usp=sharing
Expected behavior
I expect these functions to compile with torch.compile, which would give a big speedup if it works. The Colab notebook is a short example of a few basic flows, adapted from the Linear4bit example in the bitsandbytes docs: https://huggingface.co/docs/bitsandbytes/main/en/reference/nn/linear4bit#bitsandbytes.nn.Linear4bit.example
First, compilation fails on a setattr:
Unsupported: setattr(UserDefinedObjectVariable) <function Module.__setattr__ at 0x7c67c77f6f80>
from user code:
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 460, in forward
self.set_compute_type(x)
File "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 415, in set_compute_type
self.compute_dtype = x.dtype
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
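The failing line is a lazy initialisation inside forward: on its first call, Linear4bit picks compute_dtype from the input and then flips compute_type_is_set, so the module is mutated during exactly the call that torch.compile traces when that call is the first one. The framework-free sketch below reproduces only that control flow. The class name is hypothetical, the attribute names are borrowed from bitsandbytes 0.43.x, and no real math is done; it just records which attributes get written on each call.

```python
class LazyComputeTypeLayer:
    """Hypothetical stand-in for a layer that picks its compute dtype on first use.

    Attribute names mirror bitsandbytes 0.43.x Linear4bit; this is an
    illustration of the control flow only.
    """

    def __init__(self):
        # Use object.__setattr__ so construction itself is not recorded.
        object.__setattr__(self, "writes", [])
        object.__setattr__(self, "compute_dtype", None)
        object.__setattr__(self, "compute_type_is_set", False)

    def __setattr__(self, name, value):
        # Record every attribute write made after construction.
        self.writes.append(name)
        object.__setattr__(self, name, value)

    def set_compute_type(self, input_dtype):
        # Corresponds to `self.compute_dtype = x.dtype` in the traceback.
        self.compute_dtype = input_dtype

    def forward(self, input_dtype):
        if not self.compute_type_is_set:
            self.set_compute_type(input_dtype)
            self.compute_type_is_set = True
        return self.compute_dtype


layer = LazyComputeTypeLayer()
layer.forward("float32")
first_call_writes = list(layer.writes)
layer.forward("float32")
second_call_writes = layer.writes[len(first_call_writes):]
print(first_call_writes)   # ['compute_dtype', 'compute_type_is_set']
print(second_call_writes)  # []
```

Only the first call writes to the module; once the flag is set, forward has no side effects on the layer's attributes.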
torch.compile does not like changing the dtype here, so instead we force the model to use torch.float32 (bfloat16 works too) and mark the compute type as already set in each layer, so that set_compute_type is skipped:
for layer in quantized_model:
    # Skip the lazy set_compute_type() call in Linear4bit.forward
    layer.compute_type_is_set = True
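A slightly more defensive version of that loop is sketched below. The helper name pin_compute_dtype is hypothetical, and it assumes quantized layers expose compute_dtype and compute_type_is_set (as Linear4bit does in bitsandbytes 0.43.x). It sets the dtype explicitly instead of relying on the constructor argument, and skips anything without the flag, so it can be given every module in a model rather than only a flat container of quantized layers.

```python
from types import SimpleNamespace


def pin_compute_dtype(modules, dtype):
    """Pre-set the lazily chosen compute dtype on every layer that has one.

    Hypothetical helper. Assumes layers expose `compute_dtype` and
    `compute_type_is_set`; modules without the flag are skipped.
    Returns the number of layers that were pinned.
    """
    pinned = 0
    for module in modules:
        if hasattr(module, "compute_type_is_set"):
            # What set_compute_type would pick for a float32 or bfloat16 input.
            module.compute_dtype = dtype
            # forward() then never calls set_compute_type.
            module.compute_type_is_set = True
            pinned += 1
    return pinned


# Stand-ins for a container plus two quantized layers (illustration only).
container = SimpleNamespace()
q1 = SimpleNamespace(compute_dtype=None, compute_type_is_set=False)
q2 = SimpleNamespace(compute_dtype=None, compute_type_is_set=False)
count = pin_compute_dtype([container, q1, q2], "float32")
print(count)  # 2
```

With real bitsandbytes layers this would presumably be called as pin_compute_dtype(quantized_model.modules(), torch.float32) before torch.compile(quantized_model); nn.Module.modules() also yields nested submodules, which the flat loop would miss.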
There are similar errors with the bias, but who uses bias anyway? So we can disable it too.
The meatier error is now:
[2024-04-16 17:02:02,245] [0/1_1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py:415 in set_compute_type (Params4bit.to.to.to.Linear4bit.set_compute_type) (inline depth: 3)
[2024-04-16 17:02:02,245] [0/1_1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] self.compute_dtype = x.dtype
[2024-04-16 17:02:02,246] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x []
[2024-04-16 17:02:02,248] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR dtype [TensorVariable()]
[2024-04-16 17:02:02,249] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST self [ConstantVariable(dtype)]
[2024-04-16 17:02:02,250] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_ATTR compute_dtype [ConstantVariable(dtype), UnspecializedNNModuleVariable(Linear4bit)]
[2024-04-16 17:02:02,252] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint
[2024-04-16 17:02:02,253] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object set_compute_type at 0x789318d28b30, file "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 411>
[2024-04-16 17:02:02,255] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint
[2024-04-16 17:02:02,256] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object forward at 0x789318d28c90, file "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 442>
[2024-04-16 17:02:02,257] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint
[2024-04-16 17:02:02,258] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object _call_impl at 0x789326b383a0, file "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1513>
[2024-04-16 17:02:02,259] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint
---------------------------------------------------------------------------
Unsupported Traceback (most recent call last)
<ipython-input-3-60e33e8a9b3e> in <cell line: 36>()
34 x = torch.randn(1, 64).to(device=0, dtype=torch.float32)
35 y_fp16 = fp16_model(x)
---> 36 y_quantized = quantized_model(x)
43 frames
/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py in unimplemented(msg)
191 def unimplemented(msg: str) -> NoReturn:
192 assert msg != os.environ.get("BREAK", False)
--> 193 raise Unsupported(msg)
194
195
Unsupported: call_method UserDefinedObjectVariable(Params4bit) t [] {}
from user code:
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 468, in forward
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
I suspect this is a PyTorch problem, but if there is a simple way to enable compilation with CUDA Graphs, it could give a big speedup.
Dear @stillmatic, thanks for providing a minimal reproducible example and the helpful description! torch.compile support is one of the things we're currently looking into very actively.
Alongside the multi-platform refactor, it's one of the high-impact topics that need extra time and focus right now. That means we can't be very responsive on issues at the moment, as I'm the only person working full-time on BNB.
Making torch.compile work under the hood is quite a complex topic, but we actually have a call with the PyTorch devs scheduled for tomorrow, partly about this.
We'll post an update in this issue when there's progress.