bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.

Home Page: https://huggingface.co/docs/bitsandbytes/main/en/index


torch compile support?

stillmatic opened this issue

System Info

colab

torch: 2.2.1+cu121
bnb: 0.43.1

Reproduction

https://colab.research.google.com/drive/1SYd8E0SmELbYiVEwfPezReeI0ToNVYCk?usp=sharing

Expected behavior

I expect the functions to compile properly with torch.compile; that would be a big speedup if it works. The colab is a short example of some basic flows, adapted from https://huggingface.co/docs/bitsandbytes/main/en/reference/nn/linear4bit#bitsandbytes.nn.Linear4bit.example.
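For reference, a minimal sketch of the kind of flow being compiled, assuming the layer sizes and setup from the linked docs example (the actual colab contents may differ slightly):

```python
import torch
import torch.nn as nn
import bitsandbytes as bnb

# Build an fp16 reference model and a 4-bit counterpart, as in the docs example.
fp16_model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
quantized_model = nn.Sequential(bnb.nn.Linear4bit(64, 64), bnb.nn.Linear4bit(64, 64))
quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to(0)  # moving to GPU quantizes the weights

# Compiling and then running the quantized model triggers the graph breaks below.
quantized_model = torch.compile(quantized_model)
x = torch.randn(1, 64).to(device=0, dtype=torch.float32)
y = quantized_model(x)
```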

First, there is a bug with setattr:

Unsupported: setattr(UserDefinedObjectVariable) <function Module.__setattr__ at 0x7c67c77f6f80>

from user code:
   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 460, in forward
    self.set_compute_type(x)
  File "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 415, in set_compute_type
    self.compute_dtype = x.dtype

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

torch does not allow mutating the dtype attribute here during tracing. So we instead force the model to use torch.float32 (works with bfloat16 too) and mark the compute type as already set in each layer, so forward skips the assignment:

# mark every Linear4bit layer so forward() skips set_compute_type()
for layer in quantized_model.modules():
    if isinstance(layer, bnb.nn.Linear4bit):
        layer.compute_type_is_set = True

There are similar errors with the bias cast, but this example doesn't need a bias, so we can disable it as well (see the sketch below).
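Putting the pieces together, a rough sketch of the workaround (my own illustration, not verified end to end; constructor arguments as in bitsandbytes 0.43.1):

```python
import torch
import bitsandbytes as bnb

# 4-bit layers with bias disabled and a fixed compute dtype, so forward()
# never needs to cast a bias or infer the compute dtype from the input.
quantized_model = torch.nn.Sequential(
    bnb.nn.Linear4bit(64, 64, bias=False, compute_dtype=torch.float32),
    bnb.nn.Linear4bit(64, 64, bias=False, compute_dtype=torch.float32),
).to(0)

# Skip the setattr in set_compute_type() that Dynamo cannot trace.
for layer in quantized_model.modules():
    if isinstance(layer, bnb.nn.Linear4bit):
        layer.compute_type_is_set = True

compiled_model = torch.compile(quantized_model)
# With the setattr issues worked around, running this still fails inside
# bnb.matmul_4bit / self.weight.t(), as shown in the error below.
y = compiled_model(torch.randn(1, 64, device=0, dtype=torch.float32))
```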

The meatier error is now:

[2024-04-16 17:02:02,245] [0/1_1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py:415 in set_compute_type (Params4bit.to.to.to.Linear4bit.set_compute_type) (inline depth: 3)
[2024-04-16 17:02:02,245] [0/1_1] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]                 self.compute_dtype = x.dtype
[2024-04-16 17:02:02,246] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x []
[2024-04-16 17:02:02,248] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_ATTR dtype [TensorVariable()]
[2024-04-16 17:02:02,249] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST self [ConstantVariable(dtype)]
[2024-04-16 17:02:02,250] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_ATTR compute_dtype [ConstantVariable(dtype), UnspecializedNNModuleVariable(Linear4bit)]
[2024-04-16 17:02:02,252] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint
[2024-04-16 17:02:02,253] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object set_compute_type at 0x789318d28b30, file "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 411>
[2024-04-16 17:02:02,255] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint
[2024-04-16 17:02:02,256] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object forward at 0x789318d28c90, file "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 442>
[2024-04-16 17:02:02,257] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint
[2024-04-16 17:02:02,258] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object _call_impl at 0x789326b383a0, file "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1513>
[2024-04-16 17:02:02,259] [0/1_1] torch._dynamo.symbolic_convert: [DEBUG] empty checkpoint

---------------------------------------------------------------------------

Unsupported                               Traceback (most recent call last)

[<ipython-input-3-60e33e8a9b3e>](https://localhost:8080/#) in <cell line: 36>()
     34 x = torch.randn(1, 64).to(device=0, dtype=torch.float32)
     35 y_fp16 = fp16_model(x)
---> 36 y_quantized = quantized_model(x)

43 frames

[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py](https://localhost:8080/#) in unimplemented(msg)
    191 def unimplemented(msg: str) -> NoReturn:
    192     assert msg != os.environ.get("BREAK", False)
--> 193     raise Unsupported(msg)
    194 
    195 

Unsupported: call_method UserDefinedObjectVariable(Params4bit) t [] {}

from user code:
   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/bitsandbytes/nn/modules.py", line 468, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I'm tempted to believe that this is a PyTorch problem, but if there's a simple way to enable compilation with CUDA Graphs, that could be a big speedup.
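For context (my addition, not from the original report), CUDA Graphs are reached through torch.compile's reduce-overhead mode, reusing the quantized_model from the sketches above; it only helps once the model actually compiles without graph breaks:

```python
import torch

# "reduce-overhead" uses CUDA Graphs to cut per-call launch overhead.
compiled_model = torch.compile(quantized_model, mode="reduce-overhead")
```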

Dear @stillmatic, thanks for providing a minimal reproducible example and the helpful description! torch.compile support is one of the things we're currently looking into very actively.

Alongside the multi-platform refactor, it's one of the high-impact topics that needs extra time and focus, which is why we can't be very responsive on issues at the moment, as I'm the only person working full-time on BNB right now.

Regarding torch.compile, it's quite complex to make things work under the hood, but we actually have a call with the PyTorch devs scheduled tomorrow, partially about this topic.

We'll let you know in this issue when there's progress.