omerbt / TokenFlow

Official PyTorch Implementation for "TokenFlow: Consistent Diffusion Features for Consistent Video Editing" presenting "TokenFlow" (ICLR 2024)

Home Page: https://diffusion-tokenflow.github.io

not compatible with diffusers 0.21+ [with workaround]

eps696 opened this issue

Everything runs OK with diffusers 0.20 or below, but with diffusers 0.21 I get this error:

File "F:\_neuro\SDfu\lib\tokenflow.py", line 185, in denoise_step
  noise_pred = self.unet(lat_in, t, conds).sample
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\diffusers\models\unet_2d_condition.py", line 1018, in forward
  sample = upsample_block(
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\diffusers\models\unet_2d_blocks.py", line 2227, in forward
  hidden_states = resnet(hidden_states, temb, scale=lora_scale)
File "C:\Users\eps\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
  return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'scale'

I've got PyTorch 2.0.1 and xformers 0.0.21 (but again, only the diffusers version determines whether this error appears).
UPD: the error occurs only with the PnP method; SDEdit works OK.

There are some similar issues on the diffusers GitHub; maybe they help:
huggingface/diffusers#3348
huggingface/diffusers#5028

In fact, here is a workaround using a dummy argument:
change this line to `def forward(input_tensor, temb, scale=None):`
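The failure and the dummy-argument fix can be illustrated without diffusers or torch. This is a minimal sketch under stated assumptions: `FakeResnetBlock`, `patch_forward` and `call_like_new_diffusers` are hypothetical names, not TokenFlow or diffusers code. It assumes, as the suggested signature without `self` hints, that the patched `forward` is a closure assigned on the module instance, which newer diffusers then calls with the `scale=` keyword seen in the traceback.

```python
# Hypothetical stand-ins only: FakeResnetBlock mimics a diffusers ResNet block,
# patch_forward mimics installing a forward override on a module instance.

class FakeResnetBlock:
    """Hypothetical stand-in for a diffusers ResNet block (no torch needed)."""

    def __init__(self, offset):
        self.offset = offset

    def forward(self, input_tensor, temb, scale=1.0):
        return input_tensor + temb + self.offset

    def __call__(self, *args, **kwargs):
        # Like torch.nn.Module.__call__: dispatch to whatever `forward` is set.
        return self.forward(*args, **kwargs)


def patch_forward(block, accept_scale):
    """Replace block.forward with a closure that has no `self` parameter."""
    if accept_scale:
        def forward(input_tensor, temb, scale=None):
            # Workaround: accept `scale` as a dummy argument and ignore it.
            return input_tensor + temb + block.offset
    else:
        def forward(input_tensor, temb):
            # Old signature: callers that pass `scale=` will raise TypeError.
            return input_tensor + temb + block.offset
    # Assigned on the instance, so Python does not bind `self` to it.
    block.forward = forward


def call_like_new_diffusers(block, hidden_states, temb, lora_scale=1.0):
    # Mirrors the failing call: resnet(hidden_states, temb, scale=lora_scale)
    return block(hidden_states, temb, scale=lora_scale)


if __name__ == "__main__":
    broken = FakeResnetBlock(offset=1)
    patch_forward(broken, accept_scale=False)
    try:
        call_like_new_diffusers(broken, 2, 3)
    except TypeError as err:
        print("old signature:", err)

    fixed = FakeResnetBlock(offset=1)
    patch_forward(fixed, accept_scale=True)
    print("dummy scale:", call_like_new_diffusers(fixed, 2, 3))
```

The mismatch only surfaces at call time, not when the override is installed, which is why the same code works until the caller starts passing `scale`. With the dummy argument, any `scale` value handed to the patched `forward` is simply ignored.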

I have no idea whether it breaks anything related to LoRA (nor whether LoRA is compatible with TokenFlow in general :)

@eps696 Really helpful. Thanks! :-)