openai / consistency_models

Official repo for consistency models.

QKVFlashAttention unexpected parameters error, running in Google Colab

JonathanFly opened this issue · comments

I tried to generate samples in Colab and everything works, except that I had to change this code in /cm/unet.py, clearing out factory_kwargs.

Not sure if this is a bug or I did something wrong. This is how I ran it: https://github.com/JonathanFly/consistency_models_colab_notebook/blob/main/Consistency_Models_Make_Samples.ipynb


class QKVFlashAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        from einops import rearrange
        from flash_attn.flash_attention import FlashAttention

        assert batch_first
        #factory_kwargs = {"device": device, "dtype": dtype}
        factory_kwargs = {}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
commented

I also hit the same error; I guess this codebase was written against an older version of flash_attn.

Logging to /tmp/openai-2023-04-13-14-33-55-278549
creating model and diffusion...
Traceback (most recent call last):
  File "/content/consistency_models/scripts/image_sample.py", line 143, in <module>
    main()
  File "/content/consistency_models/scripts/image_sample.py", line 37, in main
    model, diffusion = create_model_and_diffusion(
  File "/content/consistency_models/cm/script_util.py", line 76, in create_model_and_diffusion
    model = create_model(
  File "/content/consistency_models/cm/script_util.py", line 140, in create_model
    return UNetModel(
  File "/content/consistency_models/cm/unet.py", line 612, in __init__
    AttentionBlock(
  File "/content/consistency_models/cm/unet.py", line 293, in __init__
    self.attention = QKVFlashAttention(channels, self.num_heads)
  File "/content/consistency_models/cm/unet.py", line 359, in __init__
    self.inner_attn = FlashAttention(
TypeError: __init__() got an unexpected keyword argument 'device'

A related minor issue: in th_evaluator.py, inception-2015-12-05.pt tries to download automatically but fails, and it doesn't seem like you can pass the path on the command line. Also, is it supposed to automatically calculate stats for a reference batch? (I'm probably trying to run the sample out of order?)

class FIDAndIS:
    def __init__(
        self,
        softmax_batch_size=512,
        clip_score_batch_size=512,
        path="https://openaipublic.blob.core.windows.net/consistency/inception/inception-2015-12-05.pt",
    ):
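
A possible workaround (my own sketch, assuming FIDAndIS treats a non-URL path as a local file and that it lives in cm/th_evaluator.py) is to download the checkpoint manually and point the evaluator at it:

# Hypothetical workaround: fetch the inception checkpoint once, then pass a
# local path instead of relying on the automatic (and failing) download.
import urllib.request

INCEPTION_URL = (
    "https://openaipublic.blob.core.windows.net/consistency/inception/"
    "inception-2015-12-05.pt"
)
LOCAL_PATH = "/content/inception-2015-12-05.pt"  # assumed writable location in Colab

urllib.request.urlretrieve(INCEPTION_URL, LOCAL_PATH)

# Assumption: FIDAndIS accepts a plain filesystem path for `path`.
from cm.th_evaluator import FIDAndIS  # module path is an assumption

fid_is = FIDAndIS(path=LOCAL_PATH)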

commented

Using pip install flash-attn==0.2.8 solved it for me.

Using pip install flash-attn==0.2.8 solved it for me.

After doing this, I started training the model with these parameters and then an error came up. Does anyone know what it means? I'm a PyTorch rookie.

(py3_8_16) bld@bld:~/consistency_models/scripts$ mpiexec -n 1 python cm_train.py --training_mode consistency_training --target_ema_mode adaptive --start_ema 0.95 --scale_mode progressive --start_scales 2 --end_scales 150 --total_training_steps 100000 --loss_norm lpips --lr_anneal_steps 0 --teacher_model_path /home/bld/pre_train_model/edm_bedroom256_ema.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.9999,0.99994,0.9999432189950708 --global_batch_size 1 --image_size 256 --lr 0.00005 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 --weight_schedule uniform --data_dir /home/bld/lsun/lsun_train_output_dir
Logging to /tmp/openai-2023-04-19-10-51-33-746807
creating model and diffusion...
creating data loader...
loading the teacher model from /home/bld/pre_train_model/edm_bedroom256_ema.pt
creating the target model
training...
/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Traceback (most recent call last):
  File "cm_train.py", line 171, in <module>
    main()
  File "cm_train.py", line 121, in main
    CMTrainLoop(
  File "/home/bld/consistency_models/cm/train_util.py", line 367, in run_loop
    self.run_step(batch, cond)
  File "/home/bld/consistency_models/cm/train_util.py", line 389, in run_step
    self.forward_backward(batch, cond)
  File "/home/bld/consistency_models/cm/train_util.py", line 501, in forward_backward
    losses = compute_losses()
  File "/home/bld/consistency_models/cm/karras_diffusion.py", line 191, in consistency_losses
    distiller = denoise_fn(x_t, t)
  File "/home/bld/consistency_models/cm/karras_diffusion.py", line 125, in denoise_fn
    return self.denoise(model, x, t, **model_kwargs)[1]
  File "/home/bld/consistency_models/cm/karras_diffusion.py", line 347, in denoise
    model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/consistency_models/cm/unet.py", line 765, in forward
    h = module(h, emb)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/consistency_models/cm/unet.py", line 77, in forward
    x = layer(x)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/consistency_models/cm/unet.py", line 308, in forward
    return checkpoint(
  File "/home/bld/consistency_models/cm/nn.py", line 155, in checkpoint
    return func(*inputs)
  File "/home/bld/consistency_models/cm/unet.py", line 325, in _forward
    h = checkpoint(self.attention, (qkv,), (), self.use_attention_checkpoint)
  File "/home/bld/consistency_models/cm/nn.py", line 155, in checkpoint
    return func(*inputs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/consistency_models/cm/unet.py", line 368, in forward
    qkv, _ = self.inner_attn(
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attention.py", line 47, in forward
    output = flash_attn_unpadded_qkvpacked_func(
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 266, in flash_attn_unpadded_qkvpacked_func
    return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 58, in forward
    out, softmax_lse, S_dmask = _flash_attn_forward(
  File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 21, in _flash_attn_forward
    softmax_lse, *rest = flash_attn_cuda.fwd(
RuntimeError: Expected q.stride(-1) == 1 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpiexec detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:

  Process name: [[48068,1],0]
  Exit code:    1
commented

RuntimeError: Expected q.stride(-1) == 1 to be true, but got false.

I got the same problem as well

Have you solved this problem? I don't even know what the error message means.

Solution:
Make the following change in /content/consistency_models/cm/unet.py, line 359, in __init__:

-        self.inner_attn = FlashAttention(
-            attention_dropout=attention_dropout, **factory_kwargs
-        )
+        self.inner_attn = FlashAttention(attention_dropout=attention_dropout)
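
If you want one code path that works across flash_attn versions, a possible variant (my own sketch, not from the repo) is to forward device/dtype only when the installed FlashAttention constructor actually accepts them:

import inspect

from flash_attn.flash_attention import FlashAttention


def make_flash_attention(attention_dropout=0.0, device=None, dtype=None):
    # Some flash_attn releases accept device/dtype in FlashAttention.__init__
    # and others do not; checking the signature avoids the
    # "unexpected keyword argument 'device'" TypeError either way.
    accepted = inspect.signature(FlashAttention.__init__).parameters
    kwargs = {"attention_dropout": attention_dropout}
    if "device" in accepted:
        kwargs["device"] = device
    if "dtype" in accepted:
        kwargs["dtype"] = dtype
    return FlashAttention(**kwargs)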
commented

I tried that, but it didn't solve the problem. Were there any other changes you made?

Since I'm running on a V100, I also had to disable flash attention (apparently it only works on the A100):

index 3fe5184..d9f7c2f 100644
--- a/cm/unet.py
+++ b/cm/unet.py
@@ -270,7 +270,7 @@ class AttentionBlock(nn.Module):
         num_heads=1,
         num_head_channels=-1,
         use_checkpoint=False,
-        attention_type="flash",
+        attention_type="default", #"flash", # disable flash-attention by default in order to run on V100
         encoder_channels=None,
         dims=2,
         channels_last=False,
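
If you'd rather not hard-code the new default, a small helper (my own sketch; it assumes the flash kernels need a compute-capability-8.0 card such as the A100, and that "default" selects the non-flash attention path) could pick the attention type at runtime:

import torch


def pick_attention_type():
    # Fall back to the plain attention implementation on pre-Ampere GPUs
    # such as the V100 (compute capability 7.0). Assumption: the flash
    # kernels used here require compute capability >= 8.0.
    if not torch.cuda.is_available():
        return "default"
    major, _minor = torch.cuda.get_device_capability()
    return "flash" if major >= 8 else "default"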

Still doesn't work for me. This is what I get for CD on ImageNet 64; I get a similar result with EDM:
[image]

I cannot obtain images of similar quality to those in the paper

@boxwayne @aarontan-git @asanakoy For the stride issue, I think it's a rearrange problem caused by the flash_attn version.

        qkv = self.rearrange(
            qkv, "b (three h d) s -> b s three h d", three=3, h=self.num_heads
        )
        # print(qkv.shape, qkv.stride())
        qkv, _ = self.inner_attn(
            qkv.contiguous(),
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            causal=self.causal,
        )

The printed result is torch.Size([1, 256, 3, 6, 64]) with strides (256, 1, 98304, 16384, 256), which means the tensor is no longer contiguous after rearranging. (The old flash_attn version might not have required this, while the new version requires the input to be contiguous.) So I simply add a contiguous operation before calling inner_attn:

qkv = qkv.contiguous()

Let me know if that solves the issue. I tested on my side and it works.
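
To see why the check trips, here is a small standalone reproduction (illustrative only; the shapes are chosen to match the print above) showing that the einops rearrange returns a non-contiguous view and that .contiguous() repacks it so the last stride becomes 1:

import torch
from einops import rearrange

num_heads, head_dim, seq_len = 6, 64, 256
# qkv as produced by the QKV projection: (batch, 3 * heads * head_dim, seq_len)
qkv = torch.randn(1, 3 * num_heads * head_dim, seq_len)

x = rearrange(qkv, "b (three h d) s -> b s three h d", three=3, h=num_heads)
print(x.shape, x.stride(-1), x.is_contiguous())
# prints something like: torch.Size([1, 256, 3, 6, 64]) 256 False

y = x.contiguous()
print(y.stride(-1), y.is_contiguous())  # 1 True, which satisfies q.stride(-1) == 1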

commented

The flash-attn I installed is version 1.0.2, and there's no problem.
[image]

commented

I tried your fix, and got the following warning message when trying to run ImageNet consistency training:

Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
grad.sizes() = [384, 384, 1, 1], strides() = [384, 1, 384, 384]
bucket_view.sizes() = [384, 384, 1, 1], strides() = [384, 1, 1, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:325.)

@aarontan-git You can just leave the warning there if you don't care about it. If you want to fix it, check the code to see which part changes the gradient strides, and adjust the tensor storage strides there to avoid the warning.
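
If you do want to get rid of it, one possible way to "adjust the tensor storage strides" (an untested sketch of my own, not code from this repo) is to force every gradient to be contiguous before DDP copies it into its buckets, by registering parameter hooks before wrapping the model in DistributedDataParallel:

import torch


def force_contiguous_grads(model: torch.nn.Module):
    # Sketch only: rewrite each gradient as a contiguous tensor so its strides
    # match the default DDP bucket layout. This may add extra copies.
    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(lambda grad: grad.contiguous())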

The flash-attn I installed is version 1.0.2, and there's no problem. [image]

When I install version 1.0.2, the following error occurs; how did you solve it?
[image]

As Jonathan said at the top, change this code in /cm/unet.py, clearing out factory_kwargs:

class QKVFlashAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        from einops import rearrange
        from flash_attn.flash_attention import FlashAttention

        assert batch_first
        #factory_kwargs = {"device": device, "dtype": dtype}
        factory_kwargs = {}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal

What GPU are you using?
This doesn't seem to have anything to do with the flash-attn version; when I change attention_type="flash" to "default", the code runs, but the results are poor. If I don't change it, I get the following error message:
[image]

My GPU is a V100.

I used a single A100.