mosaicml / examples

Fast and flexible reference benchmarks

FlashAttention Triton error on the MosaicBERT models other than base

Taytay opened this issue

When I try to run MosaicBERT like this:

import transformers  # needed for transformers.BertConfig below
from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048')
# It's essential to set the attention_probs_dropout_prob to 0.1
#config.alibi_starting_size = 2048 # maximum sequence length updated to 2048 from config default of 512
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)

mlm.to("cuda")
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")

classifier("I [MASK] to the store yesterday.")

I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:1124, in ast_to_ttir(fn, signature, specialization, constants, debug, arch)
   1123 try:
-> 1124     generator.visit(fn.parse())
   1125 except CompilationError as e:

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:1017, in CodeGenerator.visit(self, node)
   1016     last_loc = self.builder.get_loc()
-> 1017 ret = super().visit(node)
   1018 # Reset the location to the last one before the visit

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/ast.py:407, in NodeVisitor.visit(self, node)
    406 visitor = getattr(self, method, self.generic_visit)
--> 407 return visitor(node)

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:293, in CodeGenerator.visit_Module(self, node)
    292 def visit_Module(self, node):
--> 293     ast.NodeVisitor.generic_visit(self, node)

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/ast.py:415, in NodeVisitor.generic_visit(self, node)
    414         if isinstance(item, AST):
--> 415             self.visit(item)
    416 elif isinstance(value, AST):
...
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^
TypeError("dot() got an unexpected keyword argument 'trans_b'")

This appears to have been fixed a few days ago by @jacobfulano in the mosaic-bert-base repo:
https://huggingface.co/mosaicml/mosaic-bert-base/blob/ed2a544063a892b78823cba2858d1e098c0e6012/config.json

It looks like that removes FlashAttention? Does that mean that the speed increase from FA is also removed?

Here's how I can work around it in the meantime, in case someone else Googles this and stumbles across it:

config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base')
# It's essential to set the attention_probs_dropout_prob to 0.1, which mosaic-bert-base does. So we just update the alibi_starting_size:
config.alibi_starting_size = 2048 # maximum sequence length updated to 2048 from config default of 512
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)
mlm.to("cuda")
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")
classifier("I [MASK] to the store yesterday.")

Hi @Taytay,

Thanks for your comment. This code was written in the early days of Triton Flash Attention, and we are in the process of updating to Flash Attention 2 with support for ALiBi (see #440, which updates to PyTorch 2 and the flash-attn package instead of Triton). Earlier versions of Flash Attention did not support ALiBi, so we had to come up with a custom ALiBi implementation using Triton Flash Attention that depends on PyTorch 1.13 and triton==2.0.0.dev20221103 (here's the requirements file).
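
For anyone who still wants to run the legacy Triton path, a quick environment check can save a confusing traceback. The following is a minimal sketch, not part of the MosaicBERT code: the helper names `installed_version` and `legacy_triton_env_ok` are hypothetical, and the only facts taken from this thread are the two pins (PyTorch 1.13 and triton==2.0.0.dev20221103).

```python
from importlib import metadata
from typing import Optional

# Pins quoted in this thread; everything else here is illustrative.
TRITON_PIN = "2.0.0.dev20221103"
TORCH_PIN = ["1", "13"]  # major, minor


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def legacy_triton_env_ok(triton_version: Optional[str],
                         torch_version: Optional[str]) -> bool:
    """True only if both versions match the pins for the legacy Triton kernel."""
    if triton_version is None or torch_version is None:
        return False
    # Drop a local build suffix such as "+cu117" before comparing major.minor.
    torch_major_minor = torch_version.split("+")[0].split(".")[:2]
    return triton_version == TRITON_PIN and torch_major_minor == TORCH_PIN
```

Calling `legacy_triton_env_ok(installed_version("triton"), installed_version("torch"))` before loading the model reports whether the environment matches those pins. A newer Triton that rejects the `trans_b` keyword, as in the traceback above, would return False.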

The config value of attention_probs_dropout_prob: 0.1 turns off Triton Flash Attention, while attention_probs_dropout_prob: 0.0 turns it on (you can see this in the source code here).
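
The rule can be written down as a small predicate. This is only a sketch of the behaviour described in this reply, not the actual check in the model source: the function name is hypothetical, and treating every non-zero dropout value as "off" is an assumption, since only 0.0 and 0.1 are mentioned here. A `SimpleNamespace` stands in for a real `BertConfig` so that the snippet runs without downloading anything.

```python
from types import SimpleNamespace


def uses_triton_flash_attention(attention_probs_dropout_prob: float) -> bool:
    """Mirror the stated rule: 0.0 enables the Triton path, 0.1 disables it.

    Assumption: any other non-zero value is treated like 0.1.
    """
    return attention_probs_dropout_prob == 0.0


# Stand-in for a loaded config; a real BertConfig exposes the same attribute.
config = SimpleNamespace(attention_probs_dropout_prob=0.1)
triton_path_enabled = uses_triton_flash_attention(config.attention_probs_dropout_prob)
```

With the Hugging Face default of 0.1, `triton_path_enabled` is False, so the model does not reach the Triton kernel that raised the `trans_b` error.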

This was understandably confusing for a lot of people, so I have set the default in Hugging Face to attention_probs_dropout_prob: 0.1 (e.g. here) across all mosaicbert models. It should be noted that most of the benefits of Flash Attention come during training, and not necessarily by serving a loaded model.

Your code should now work with the config from mosaic-bert-base-seqlen-2048. Note that this does not use our custom Triton Flash Attention implementation.

config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048')
config.alibi_starting_size = 2048 # maximum sequence length updated to 2048 from config default
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)
mlm.to("cuda")
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")
classifier("I [MASK] to the store yesterday.")

Thank you for the thorough explanation! I tried training yesterday and ran into some more errors, so #440 is especially welcome! (I have a model I need to train on a large number of tokens, so the performance improvement is going to be particularly helpful.)