huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.

Home Page:https://huggingface.co/docs/diffusers

Bug in SlicedAttnProcessor

shinetzh opened this issue

Describe the bug

[Screenshot of the SlicedAttnProcessor source showing the slicing loop]

"batch_size_attention // self.slice_size" will forget to compute the last batch of attention。

should be this:

"(batch_size_attention - 1) // self.slice_size + 1"

I have submitted a pull request for this issue:
#8836
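
For illustration, a minimal standalone sketch of the off-by-one (the variable names mirror the processor, but this is not the actual diffusers source): with 5 attention batches and slice_size=2, floor division yields 2 slices and the fifth batch is never processed, while the ceiling form yields 3 slices and covers everything.

batch_size_attention = 5  # e.g. batch_size * heads = 1 * 5, as in the reproduction below
slice_size = 2

num_slices_floor = batch_size_attention // slice_size           # 2 -> rows 0..3 only
num_slices_ceil = (batch_size_attention - 1) // slice_size + 1  # 3 -> rows 0..4

for i in range(num_slices_ceil):
    start_idx = i * slice_size
    end_idx = min((i + 1) * slice_size, batch_size_attention)
    print(f"slice {i}: rows {start_idx}..{end_idx - 1}")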

Reproduction

import torch
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor

# With 5 heads and batch_size=1, batch_size_attention = batch_size * heads = 5,
# which slice_size=2 does not divide evenly, so the last slice is skipped.
attn = Attention(320, None, heads=5, dim_head=64)
slice_proc = SlicedAttnProcessor(slice_size=2)

hidden_states = torch.randn(1, 4096, 320)

result1 = slice_proc(attn=attn, hidden_states=hidden_states)

# Duplicating the batch makes batch_size_attention = 10, which is divisible by 2,
# so the same input now goes through the full attention computation.
hidden_states = torch.cat([hidden_states] * 2, dim=0)
result2 = slice_proc(attn=attn, hidden_states=hidden_states)

print(torch.sum(torch.abs(result1 - result2[:1])))  # Ideally, the numerical difference should be 0

Logs

No response

System Info

  • 🤗 Diffusers version: 0.29.0
  • Platform: Linux-4.14.105-1-tlinux3-0013-x86_64-with-glibc2.28
  • Running on a notebook?: No
  • Running on Google Colab?: No
  • Python version: 3.9.19
  • PyTorch version (GPU?): 2.1.2+cu118 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.23.3
  • Transformers version: 4.41.2
  • Accelerate version: 0.31.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.3
  • xFormers version: 0.0.23.post1+cu118
  • Accelerator: Tesla V100-SXM2-32GB, 32510 MiB VRAM
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@DN6 @yiyixuxu @sayakpaul
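
The expanded script below compares SlicedAttnProcessor with slice sizes 1, 2, and 3 against AttnProcessor2_0 on a batch of 2: slice sizes 1 and 2 divide batch_size * heads = 10, while slice_size=3 does not, which is where the mismatch shows up.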

import torch
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor, AttnProcessor2_0

attn = Attention(320, None, heads=5, dim_head=64)
slice_proc1 = SlicedAttnProcessor(slice_size=1)
slice_proc2 = SlicedAttnProcessor(slice_size=2)
slice_proc3 = SlicedAttnProcessor(slice_size=3)
attn_proc = AttnProcessor2_0()

# batch_size * heads = 2 * 5 = 10 attention batches in total.
hidden_states = torch.randn(2, 4096, 320)

result1 = slice_proc1(attn=attn, hidden_states=hidden_states)  # 10 % 1 == 0, all slices computed
result2 = slice_proc2(attn=attn, hidden_states=hidden_states)  # 10 % 2 == 0, all slices computed
result3 = slice_proc3(attn=attn, hidden_states=hidden_states)  # 10 % 3 != 0, last slice is skipped
result = attn_proc(attn=attn, hidden_states=hidden_states)     # non-sliced reference attention

print(torch.sum(torch.abs(result1 - result)))  # Ideally, the numerical difference should be 0
print(torch.sum(torch.abs(result2 - result)))
print(torch.sum(torch.abs(result3 - result)))  # Large difference here exposes the bug