AUTOMATIC1111 / stable-diffusion-webui-tensorrt

SDXL Support

CyberTimon opened this issue

Hello

Is SDXL support planned, as SDXL is slow on most computers?

Kind regards,
Timon Käch


You can already try it, but you need to modify the code.

Speed went from 6.67 it/s up to 12.10 it/s at 960x1024 (w x h), 21 steps.


1. Export the UNet to ONNX with the new method

```python
import os

import torch

from modules import sd_hijack, sd_unet
from modules import shared, devices


def export_current_unet_to_onnx(filename, opset_version=17):
    # Dummy inputs for tracing; SDXL additionally takes a conditioning vector y.
    x = torch.randn(1, 4, 16, 16).to(devices.device, devices.dtype)
    timesteps = torch.zeros((1,)).to(devices.device, devices.dtype) + 500
    context = torch.randn(1, 77, 2048).to(devices.device, devices.dtype)
    y = torch.randn(1, 2816).to(devices.device, devices.dtype)

    # Gradient checkpointing breaks ONNX tracing, so switch it off everywhere.
    def disable_checkpoint(self):
        if getattr(self, 'use_checkpoint', False):
            self.use_checkpoint = False
        if getattr(self, 'checkpoint', False):
            self.checkpoint = False

    shared.sd_model.model.diffusion_model.apply(disable_checkpoint)

    sd_unet.apply_unet("None")
    sd_hijack.model_hijack.apply_optimizations('None')

    os.makedirs(os.path.dirname(filename), exist_ok=True)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    shared.sd_model.model.diffusion_model = shared.sd_model.model.diffusion_model.to(device)

    with devices.autocast():
        torch.onnx.export(
            shared.sd_model.model.diffusion_model,
            (x, timesteps, context, y),
            filename,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['x', 'timesteps', 'context', 'y'],
            output_names=['output'],
            dynamic_axes={
                'x': {0: 'batch_size', 2: 'height', 3: 'width'},
                'timesteps': {0: 'batch_size'},
                'context': {0: 'batch_size', 1: 'sequence_length'},
                'y': {0: 'batch_size'},
                'output': {0: 'batch_size'},
            },
        )

    sd_hijack.model_hijack.apply_optimizations()
    sd_unet.apply_unet()
```
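
For quick testing, a minimal way to call the export helper might look like this; the output path is only an example chosen to match the trtexec command further down, not a fixed convention:

```python
# Hypothetical usage: export the currently loaded SDXL checkpoint's UNet.
export_current_unet_to_onnx("models/Unet-onnx/ttt.onnx", opset_version=17)
```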

2. Hijack the UNet forward with UNetModel_forwardy

modules/sd_hijack.py

```python
...
    if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
        ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

    # Same save-once pattern for the sgm (SDXL) module, but install the
    # y-aware forward instead.
    if not hasattr(sgm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
        sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = sgm.modules.diffusionmodules.openaimodel.UNetModel.forward

    sgm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forwardy

def undo_hijack(self, m):
    if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
        m.cond_stage_model = m.cond_stage_model.wrapped

    elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
        m.cond_stage_model = m.cond_stage_model.wrapped

        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
    elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
        m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
        m.cond_stage_model = m.cond_stage_model.wrapped

    undo_optimizations()
    undo_weighted_forward(m)

    self.apply_circular(False)
    self.layers = None
    self.clip = None

    # Restore both saved forwards (ldm and sgm) when undoing the hijack.
    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
    sgm.modules.diffusionmodules.openaimodel.UNetModel.forward = sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
...
```
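
The important detail in this hijack is the save-once/restore pattern: the original forward is stashed on the module only the first time, so a repeated hijack cannot overwrite the backup and undo_hijack can always restore it. A standalone sketch of the pattern, with made-up names:

```python
class Model:
    def forward(self, x):
        return x * 2

def patched_forward(self, x):
    return x * 3

# Save the original exactly once; a second hijack must not clobber the backup.
if not hasattr(Model, 'copy_of_forward'):
    Model.copy_of_forward = Model.forward
Model.forward = patched_forward

# ...later, undo by restoring the saved attribute:
Model.forward = Model.copy_of_forward
```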

3. modules/sd_unet.py

```python
...
class SdUnet(torch.nn.Module):
    def forward(self, x, timesteps, context, *args, **kwargs):
        raise NotImplementedError()

    def activate(self):
        pass

    def deactivate(self):
        pass

def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
    if current_unet is not None:
        return current_unet.forward(x, timesteps, context, *args, **kwargs)

    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)

# The same approach works here, extended with the SDXL conditioning vector y.
def UNetModel_forwardy(self, x, timesteps=None, context=None, y=None, **kwargs):
    if current_unet is not None:
        return current_unet.forward(x, timesteps, context, y, **kwargs)

    return sgm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, y, **kwargs)
...
```
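
For context, a hedged sketch of how a backend could implement SdUnet so that the y-aware dispatch above reaches it; the class and engine names here are hypothetical, not part of the extension:

```python
# Hypothetical SdUnet implementation that accepts SDXL's y.
class MySdxlUnet(SdUnet):
    def __init__(self, engine):
        super().__init__()
        self.engine = engine  # stand-in for e.g. a TensorRT engine wrapper

    def forward(self, x, timesteps, context, y=None, **kwargs):
        # UNetModel_forwardy routes here while this unet is active;
        # SD 1.x callers come through UNetModel_forward and never pass y.
        return self.engine.infer(x, timesteps, context, y)
```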

4. extensions/stable-diffusion-webui-tensorrt/scripts/trt.py

```python
def forward(self, x, timesteps, context, *args, **kwargs):
    a, b, c, d = x.shape

    if a == 1:
        # Note: this single-sample path does not forward y.
        self.infer({"x": x, "timesteps": timesteps, "context": context})
        return self.buffers["output"].to(dtype=x.dtype, device=devices.device)
    else:
        # The engine was built with batch size 1 (see the trtexec shapes below),
        # so split the batch, run each sample separately, and concatenate.
        images = []
        for i in range(a):
            with contextlib.suppress(Exception):
                s = x[i].unsqueeze(0)
                t = timesteps[i].unsqueeze(0)
                ctx = context[i].unsqueeze(0)
                if args:
                    # SDXL path: the first extra positional argument is y.
                    y = args[0][i].unsqueeze(0)
                    self.infer({"x": s, "timesteps": t, "context": ctx, "y": y})
                else:
                    self.infer({"x": s, "timesteps": t, "context": ctx})

                images.append(self.buffers["output"].to(dtype=x.dtype, device=devices.device))
        return torch.cat(images, dim=0)
```

5. Fix the device-mismatch problems I found

You need to track them down one by one and add model.to(devices.device), or, the easy way, use model.cuda(). There are maybe 3-4 places that need modifying.
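
As a rough sketch of the kind of change meant here (the actual call sites depend on your webui version; the helper below is just an illustration):

```python
import torch
from modules import devices  # webui helper holding the configured device

def move_to_device(module: torch.nn.Module) -> torch.nn.Module:
    # Preferred fix: follow the webui-configured device rather than
    # hard-coding CUDA; module.cuda() also works as a quick hack.
    return module.to(devices.device)
```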

6. Export the ONNX model to a TRT engine. My command:

"{full_path}/trtexec" --onnx="{full_path}/models/Unet-onnx/ttt.onnx" --saveEngine="{full_path}/models/Unet-trt/ttt.trt" --minShapes=x:1x4x64x64,context:1x77x2048,timesteps:1 --maxShapes=x:1x4x128x120,context:1x77x2048,timesteps:1 --fp16

Hey, thank you so much for the fast answer. Will try it out soon.
Is 1024x1024 not possible? Only 960x1024?

I can't be sure; maxShapes=x:1x4x128x120 is the limit for me, I can't go over this size. If I use maxShapes=x:1x4x128x128, trtexec pops up an error.