huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.

Home Page: https://huggingface.co/docs/diffusers

StableDiffusionLatentUpscalePipeline - positive/negative prompt embeds support

DeTeam opened this issue · comments

I'm trying to deploy the smallest possible SD inpainting model. My production deployment only needs the UNet + VAE + IP-Adapter weights, with the prompt and IP-Adapter image embeds pre-generated. Works well!

Now I wanted to try the latent upscaler from diffusers and realized it currently doesn't support pre-generated embeds. It would be nice to keep its API aligned with the rest and add them.

Describe the solution you'd like.

Harmonizing the inputs of StableDiffusionLatentUpscalePipeline with the other, more frequently used pipelines would be nice.
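Concretely, harmonized inputs would mean the upscaler's `__call__` accepts either a prompt or pre-generated prompt embeds, with the usual check-then-encode pattern used by other pipelines. A toy, framework-free sketch of that pattern (plain Python stand-ins, not the real pipeline; the placeholder "encoding" is invented for illustration):

```python
def call_pipeline(prompt=None, prompt_embeds=None):
    # Standard diffusers-style input check: exactly one of
    # `prompt` / `prompt_embeds` must be provided.
    if (prompt is None) == (prompt_embeds is None):
        raise ValueError("Provide exactly one of `prompt` or `prompt_embeds`.")
    if prompt_embeds is None:
        # Placeholder "encoding" step; the real pipeline would run
        # the CLIP text encoder here.
        prompt_embeds = [float(len(word)) for word in prompt.split()]
    # ...the rest of the pipeline would consume prompt_embeds...
    return prompt_embeds
```

With this shape, a deployment that pre-generates embeds can skip shipping the text encoder entirely and call the pipeline with `prompt_embeds` only.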

would be very nice indeed!
would you be willing to open a PR? if not we can ask the community to see if anyone else wants to help :)

@yiyixuxu sorry, I don't have capacity for a PR right now (I'm unfamiliar with the codebase, and testing would likely also take a while).

@yiyixuxu LatentUpscaler uses two text embeds, hidden_states and pooler_output, for both the prompt and the negative prompt:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py#L148-L149

should we change prompt_embeds from torch.FloatTensor to BaseModelOutputWithPooling?
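For reference, the encoder output that bundles those two tensors can be mimicked with a simplified stand-in (this dataclass is a toy, not the real transformers class, and the "encoder" below is a placeholder invented for illustration):

```python
from dataclasses import dataclass

@dataclass
class BaseModelOutputWithPooling:
    # Simplified stand-in for transformers' BaseModelOutputWithPooling.
    last_hidden_state: list  # per-token embeddings, conceptually (seq_len, dim)
    pooler_output: list      # single pooled embedding, conceptually (dim,)

def encode(prompt: str) -> BaseModelOutputWithPooling:
    # Placeholder "encoder": the real pipeline calls the CLIP text encoder.
    tokens = prompt.split()
    hidden = [[hash(t) % 7 for _ in range(4)] for t in tokens]
    pooled = hidden[-1] if hidden else [0, 0, 0, 0]
    return BaseModelOutputWithPooling(last_hidden_state=hidden, pooler_output=pooled)
```

The upscaler consumes both fields: the per-token hidden states for cross-attention and the pooled vector as a separate conditioning input, which is why a single prompt_embeds tensor is not enough on its own.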

@rootonchair
we can:

  1. create an encode_prompt method that's consistent with the one in other pipelines https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L275 (i.e. it should return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
  2. then refactor _encode_prompt so it uses the encode_prompt we just created:

      prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt(...)
      if do_classifier_free_guidance:
          prompt_embeds = ...
          pooled_prompt_embeds = ...
      else:
          ...

  3. and also deprecate _encode_prompt
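The refactor pattern above can be sketched with toy, framework-free stand-ins (lists instead of tensors; the method names mirror the diffusers convention, but this is not the real pipeline code and the "embeddings" are placeholder strings):

```python
import warnings

class ToyUpscalePipeline:
    """Toy stand-in illustrating the proposed refactor, not real diffusers code."""

    def encode_prompt(self, prompt, negative_prompt=None):
        # In the real pipeline these would come from the CLIP text encoder's
        # last_hidden_state and pooler_output; here they are placeholder lists.
        negative_prompt = negative_prompt or ""
        prompt_embeds = [f"hidden({prompt})"]
        pooled_prompt_embeds = [f"pooled({prompt})"]
        negative_prompt_embeds = [f"hidden({negative_prompt})"]
        negative_pooled_prompt_embeds = [f"pooled({negative_prompt})"]
        return (prompt_embeds, negative_prompt_embeds,
                pooled_prompt_embeds, negative_pooled_prompt_embeds)

    def _encode_prompt(self, prompt, do_classifier_free_guidance, negative_prompt=None):
        # Deprecated shim: keeps the old concatenated return shape for existing
        # callers, now built on top of the new encode_prompt.
        warnings.warn("_encode_prompt is deprecated; use encode_prompt instead",
                      FutureWarning)
        (prompt_embeds, negative_prompt_embeds,
         pooled_prompt_embeds, negative_pooled_prompt_embeds) = self.encode_prompt(
            prompt, negative_prompt=negative_prompt)
        if do_classifier_free_guidance:
            # Classifier-free guidance concatenates negative then positive embeds.
            prompt_embeds = negative_prompt_embeds + prompt_embeds
            pooled_prompt_embeds = negative_pooled_prompt_embeds + pooled_prompt_embeds
        return prompt_embeds, pooled_prompt_embeds
```

Keeping _encode_prompt as a thin deprecated wrapper means downstream code keeps working while new callers (and the pre-generated-embeds path from this issue) go through encode_prompt directly.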

Thanks for your guidance @yiyixuxu. Will open a PR soon