RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false
adammenges opened this issue
Adam Menges commented
I got the following error when trying to use the notebook as-is (no modifications), in the 5th cell, the one that runs pipe(...):
RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
Any ideas?
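If it helps, my guess from the trace is that the batch dimensions of the latents and the text embeddings disagree by the time they reach cross-attention. A quick shape check just before the unet call in RegionalGenerator.__call__ (variable names taken from the trace below, untested) is what I'd try next:

# hypothetical debug check, right before the self.unet(...) call in cell 1;
# the query batch comes from the latents, the key/value batch from the text embeddings
print("latent_model_input:", latent_model_input.shape)
print("text_embs:", text_embs.shape)
assert latent_model_input.shape[0] == text_embs.shape[0], "batch sizes disagree"

Happy to run that and report back if it's useful.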
Full trace below:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[9], line 1
----> 1 images = pipe(prompt,negative_prompt,
2 batch_size = 2, #batch size
3 num_inference_steps=30, # sampling step
4 height = 896,
5 width = 640,
6     end_steps = 1, # when to stop applying the attention double version, as a ratio of 0-1 (1 = applied at every step, 0 = normal generation)
7     base_ratio=0.2, # weight of the base prompt (0 = regional prompts only, 1 = base prompt only)
8 seed = 4396, # random seed
9 )
Cell In[1], line 108, in RegionalGenerator.__call__(self, prompts, negative_prompt, batch_size, height, width, guidance_scale, num_inference_steps, seed, base_ratio, end_steps)
106 #predict noise
107 with torch.no_grad():
--> 108 noise_pred = self.unet(sample = latent_model_input,timestep = t,encoder_hidden_states=text_embs).sample
110 #negative CFG
111 noise_pred_text, noise_pred_negative= noise_pred.chunk(2)
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py:905, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, encoder_attention_mask, return_dict)
903 for downsample_block in self.down_blocks:
904 if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
--> 905 sample, res_samples = downsample_block(
906 hidden_states=sample,
907 temb=emb,
908 encoder_hidden_states=encoder_hidden_states,
909 attention_mask=attention_mask,
910 cross_attention_kwargs=cross_attention_kwargs,
911 encoder_attention_mask=encoder_attention_mask,
912 )
913 else:
914 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py:993, in CrossAttnDownBlock2D.forward(self, hidden_states, temb, encoder_hidden_states, attention_mask, cross_attention_kwargs, encoder_attention_mask)
991 else:
992 hidden_states = resnet(hidden_states, temb)
--> 993 hidden_states = attn(
994 hidden_states,
995 encoder_hidden_states=encoder_hidden_states,
996 cross_attention_kwargs=cross_attention_kwargs,
997 attention_mask=attention_mask,
998 encoder_attention_mask=encoder_attention_mask,
999 return_dict=False,
1000 )[0]
1002 output_states = output_states + (hidden_states,)
1004 if self.downsamplers is not None:
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/transformer_2d.py:291, in Transformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, class_labels, cross_attention_kwargs, attention_mask, encoder_attention_mask, return_dict)
289 # 2. Blocks
290 for block in self.transformer_blocks:
--> 291 hidden_states = block(
292 hidden_states,
293 attention_mask=attention_mask,
294 encoder_hidden_states=encoder_hidden_states,
295 encoder_attention_mask=encoder_attention_mask,
296 timestep=timestep,
297 cross_attention_kwargs=cross_attention_kwargs,
298 class_labels=class_labels,
299 )
301 # 3. Output
302 if self.is_input_continuous:
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention.py:170, in BasicTransformerBlock.forward(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, timestep, cross_attention_kwargs, class_labels)
165 if self.attn2 is not None:
166 norm_hidden_states = (
167 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
168 )
--> 170 attn_output = self.attn2(
171 norm_hidden_states,
172 encoder_hidden_states=encoder_hidden_states,
173 attention_mask=encoder_attention_mask,
174 **cross_attention_kwargs,
175 )
176 hidden_states = attn_output + hidden_states
178 # 3. Feed-forward
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1190, in Module._call_impl(self, *input, **kwargs)
1186 # If we don't have any hooks, we want to skip the rest of the logic in
1187 # this function, and just call forward.
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention_processor.py:321, in Attention.forward(self, hidden_states, encoder_hidden_states, attention_mask, **cross_attention_kwargs)
317 def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
318 # The `Attention` class can call different attention processors / attention functions
319 # here we simply pass along all tensors to the selected processor class
320 # For standard processors that are defined here, `**cross_attention_kwargs` is empty
--> 321 return self.processor(
322 self,
323 hidden_states,
324 encoder_hidden_states=encoder_hidden_states,
325 attention_mask=attention_mask,
326 **cross_attention_kwargs,
327 )
File ~/jupyter/.venv/lib/python3.10/site-packages/diffusers/models/attention_processor.py:1046, in XFormersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
1043 key = attn.head_to_batch_dim(key).contiguous()
1044 value = attn.head_to_batch_dim(value).contiguous()
-> 1046 hidden_states = xformers.ops.memory_efficient_attention(
1047 query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1048 )
1049 hidden_states = hidden_states.to(query.dtype)
1050 hidden_states = attn.batch_to_head_dim(hidden_states)
File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:197, in memory_efficient_attention(query, key, value, attn_bias, p, scale, op)
117 def memory_efficient_attention(
118 query: torch.Tensor,
119 key: torch.Tensor,
(...)
125 op: Optional[AttentionOp] = None,
126 ) -> torch.Tensor:
127 """Implements the memory-efficient attention mechanism following
128     `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
129
(...)
195 :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
196 """
--> 197 return _memory_efficient_attention(
198 Inputs(
199 query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
200 ),
201 op=op,
202 )
File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:293, in _memory_efficient_attention(inp, op)
288 def _memory_efficient_attention(
289 inp: Inputs, op: Optional[AttentionOp] = None
290 ) -> torch.Tensor:
291 # fast-path that doesn't require computing the logsumexp for backward computation
292 if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
--> 293 return _memory_efficient_attention_forward(
294 inp, op=op[0] if op is not None else None
295 )
297 output_shape = inp.normalize_bmhk()
298 return _fMHA.apply(
299 op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale
300 ).reshape(output_shape)
File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/__init__.py:313, in _memory_efficient_attention_forward(inp, op)
310 else:
311 _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
--> 313 out, *_ = op.apply(inp, needs_gradient=False)
314 return out.reshape(output_shape)
File ~/jupyter/.venv/lib/python3.10/site-packages/xformers/ops/fmha/cutlass.py:106, in FwOp.apply(cls, inp, needs_gradient)
104 causal = isinstance(inp.attn_bias, LowerTriangularMask)
105 cu_seqlen_k, cu_seqlen_q, max_seqlen_q = _get_seqlen_info(inp)
--> 106 out, lse = cls.OPERATOR(
107 query=inp.query,
108 key=inp.key,
109 value=inp.value,
110 cu_seqlens_q=cu_seqlen_q,
111 cu_seqlens_k=cu_seqlen_k,
112 max_seqlen_q=max_seqlen_q,
113 compute_logsumexp=needs_gradient,
114 causal=causal,
115 scale=inp.scale,
116 )
117 ctx: Optional[Context] = None
118 if needs_gradient:
File ~/jupyter/.venv/lib/python3.10/site-packages/torch/_ops.py:442, in OpOverloadPacket.__call__(self, *args, **kwargs)
437 def __call__(self, *args, **kwargs):
438 # overloading __call__ to ensure torch.ops.foo.bar()
439 # is still callable from JIT
440 # We save the function ptr as the `op` attribute on
441 # OpOverloadPacket to access it here.
--> 442 return self._op(*args, **kwargs or {})
RuntimeError: Expected query.size(0) == key.size(0) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
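From the trace it looks like cross-attention is going through XFormersAttnProcessor. Is that expected for this notebook, or should it be using the default processor? A minimal check and fallback I could try (untested sketch; I don't know whether it interferes with the regional attention logic):

# untested sketch: see which attention processors are active, then fall back
# to the default (non-xformers) processor for comparison
from diffusers.models.attention_processor import AttnProcessor

print(set(type(p).__name__ for p in pipe.unet.attn_processors.values()))
pipe.unet.set_attn_processor(AttnProcessor())  # assumes pipe.unet is the notebook's UNet2DConditionModel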