openai / consistencydecoder

Consistency Distilled Diff VAE

Open-source model definitions

patrickvonplaten opened this issue · comments

Hey @gabgoh,

Super cool that you're open-sourcing the consistency decoder of Dalle-3 under an MIT license ❤️

Any chance you can also add the model definitions of the torch.jit binary? Otherwise it'll be quite difficult to port the model to other libraries.

I can try to get a more human readable version of this pushed, but does decoder_consistency.ckpt.code work for the time being? The complete model definition is in there.
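
(For reference, a loaded ScriptModule exposes that serialized source directly; a minimal sketch, assuming the checkpoint is saved locally as decoder.pt:)

import torch

# Load the TorchScript archive on CPU and print its serialized source.
model = torch.jit.load("decoder.pt", map_location="cpu")
print(model.code)                        # top-level forward, as dumped below
print(model.blocks.down_0_conv_0.code)   # drill into an individual block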

It is not quite as readable, as you need to plumb through various wrappers. For example, at a high level you get:

class ConvUNetVAE(Module):
  __parameters__ = []
  __buffers__ = []
  training : bool
  _is_full_backward_hook : Optional[bool]
  blocks : __torch__.torch.nn.modules.container.ModuleDict
  def forward(self: __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ConvUNetVAE,
    x: Tensor,
    t: Tensor,
    features: Tensor) -> Tensor:
    blocks = self.blocks
    output = blocks.output
    blocks0 = self.blocks
    up_0_conv_3 = blocks0.up_0_conv_3
    blocks1 = self.blocks
    up_0_conv_2 = blocks1.up_0_conv_2
    blocks2 = self.blocks
    up_0_conv_1 = blocks2.up_0_conv_1
    blocks3 = self.blocks
    up_0_conv_0 = blocks3.up_0_conv_0
    blocks4 = self.blocks
    up_1_upsamp = blocks4.up_1_upsamp
    blocks5 = self.blocks
    up_1_conv_3 = blocks5.up_1_conv_3
    blocks6 = self.blocks
    up_1_conv_2 = blocks6.up_1_conv_2
    blocks7 = self.blocks
    up_1_conv_1 = blocks7.up_1_conv_1
    blocks8 = self.blocks
    up_1_conv_0 = blocks8.up_1_conv_0
    blocks9 = self.blocks
    up_2_upsamp = blocks9.up_2_upsamp
    blocks10 = self.blocks
    up_2_conv_3 = blocks10.up_2_conv_3
    blocks11 = self.blocks
    up_2_conv_2 = blocks11.up_2_conv_2
    blocks12 = self.blocks
    up_2_conv_1 = blocks12.up_2_conv_1
    blocks13 = self.blocks
    up_2_conv_0 = blocks13.up_2_conv_0
    blocks14 = self.blocks
    up_3_upsamp = blocks14.up_3_upsamp
    blocks15 = self.blocks
    up_3_conv_3 = blocks15.up_3_conv_3
    blocks16 = self.blocks
    up_3_conv_2 = blocks16.up_3_conv_2
    blocks17 = self.blocks
    up_3_conv_1 = blocks17.up_3_conv_1
    blocks18 = self.blocks
    up_3_conv_0 = blocks18.up_3_conv_0
    blocks19 = self.blocks
    mid_conv_1 = blocks19.mid_conv_1
    blocks20 = self.blocks
    mid_conv_0 = blocks20.mid_conv_0
    blocks21 = self.blocks
    down_3_conv_2 = blocks21.down_3_conv_2
    blocks22 = self.blocks
    down_3_conv_1 = blocks22.down_3_conv_1
    blocks23 = self.blocks
    down_3_conv_0 = blocks23.down_3_conv_0
    blocks24 = self.blocks
    down_2_downsamp = blocks24.down_2_downsamp
    blocks25 = self.blocks
    down_2_conv_2 = blocks25.down_2_conv_2
    blocks26 = self.blocks
    down_2_conv_1 = blocks26.down_2_conv_1
    blocks27 = self.blocks
    down_2_conv_0 = blocks27.down_2_conv_0
    blocks28 = self.blocks
    down_1_downsamp = blocks28.down_1_downsamp
    blocks29 = self.blocks
    down_1_conv_2 = blocks29.down_1_conv_2
    blocks30 = self.blocks
    down_1_conv_1 = blocks30.down_1_conv_1
    blocks31 = self.blocks
    down_1_conv_0 = blocks31.down_1_conv_0
    blocks32 = self.blocks
    down_0_downsamp = blocks32.down_0_downsamp
    blocks33 = self.blocks
    down_0_conv_2 = blocks33.down_0_conv_2
    blocks34 = self.blocks
    down_0_conv_1 = blocks34.down_0_conv_1
    blocks35 = self.blocks
    down_0_conv_0 = blocks35.down_0_conv_0
    blocks36 = self.blocks
    embed_image = blocks36.embed_image
    blocks37 = self.blocks
    embed_time = blocks37.embed_time
    input = torch.to(features, torch.device("cuda:0"), 6)
    features0 = torch.upsample_nearest2d(input, None, [8., 8.])
    x0 = torch.cat([x, features0], 1)
    _0 = (embed_time).forward(t, )
    _1 = (embed_image).forward(x0, )
    _2 = (down_0_conv_0).forward(_1, _0, )
    _3 = (down_0_conv_1).forward(_2, _0, )
    _4 = (down_0_conv_2).forward(_3, _0, )
    _5 = (down_0_downsamp).forward(_4, _0, )
    _6 = (down_1_conv_0).forward(_5, _0, )
    _7 = (down_1_conv_1).forward(_6, _0, )
    _8 = (down_1_conv_2).forward(_7, _0, )
    _9 = (down_1_downsamp).forward(_8, _0, )
    _10 = (down_2_conv_0).forward(_9, _0, )
    _11 = (down_2_conv_1).forward(_10, _0, )
    _12 = (down_2_conv_2).forward(_11, _0, )
    _13 = (down_2_downsamp).forward(_12, _0, )
    _14 = (down_3_conv_0).forward(_13, _0, )
    _15 = (down_3_conv_1).forward(_14, _0, )
    _16 = (down_3_conv_2).forward(_15, _0, )
    _17 = (mid_conv_1).forward((mid_conv_0).forward(_16, _0, ), _0, )
    _18 = (up_3_conv_0).forward(_17, _16, _0, )
    _19 = (up_3_conv_1).forward(_18, _15, _0, )
    _20 = (up_3_conv_2).forward(_19, _14, _0, )
    _21 = (up_3_conv_3).forward(_20, _13, _0, )
    _22 = (up_2_conv_0).forward((up_3_upsamp).forward(_21, _0, ), _12, _0, )
    _23 = (up_2_conv_1).forward(_22, _11, _0, )
    _24 = (up_2_conv_2).forward(_23, _10, _0, )
    _25 = (up_2_conv_3).forward(_24, _9, _0, )
    _26 = (up_1_conv_0).forward((up_2_upsamp).forward(_25, _0, ), _8, _0, )
    _27 = (up_1_conv_1).forward(_26, _7, _0, )
    _28 = (up_1_conv_2).forward(_27, _6, _0, )
    _29 = (up_1_conv_3).forward(_28, _5, _0, )
    _30 = (up_0_conv_0).forward((up_1_upsamp).forward(_29, _0, ), _4, _0, )
    _31 = (up_0_conv_1).forward(_30, _3, _0, )
    _32 = (up_0_conv_2).forward(_31, _2, _0, )
    _33 = (up_0_conv_3).forward(_32, _1, _0, )
    return (output).forward(_33, )

and how the blocks are defined can be found in decoder/code/__torch__/torch/nn/modules/container.py:

class ModuleDict(Module):
  __parameters__ = []
  __buffers__ = []
  training : bool
  _is_full_backward_hook : Optional[bool]
  embed_image : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ImageEmbedding
  embed_time : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.TimestepEmbedding
  down_0_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ConvResblock
  down_0_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_9.ConvResblock
  down_0_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_15.ConvResblock
  down_0_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_21.ConvResblock
  down_1_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_28.ConvResblock
  down_1_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_34.ConvResblock
  down_1_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_40.ConvResblock
  down_1_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_46.ConvResblock
  down_2_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_53.ConvResblock
  down_2_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_59.ConvResblock
  down_2_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_65.ConvResblock
  down_2_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_71.ConvResblock
  down_3_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_77.ConvResblock
  down_3_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_83.ConvResblock
  down_3_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_89.ConvResblock
  mid_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_95.ConvResblock
  mid_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_101.ConvResblock
  up_3_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_108.ConvResblock
  up_3_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_115.ConvResblock
  up_3_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_122.ConvResblock
  up_3_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_129.ConvResblock
  up_3_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_135.ConvResblock
  up_2_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_142.ConvResblock
  up_2_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_149.ConvResblock
  up_2_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_156.ConvResblock
  up_2_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_163.ConvResblock
  up_2_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_169.ConvResblock
  up_1_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_176.ConvResblock
  up_1_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_183.ConvResblock
  up_1_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_190.ConvResblock
  up_1_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_197.ConvResblock
  up_1_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_203.ConvResblock
  up_0_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_210.ConvResblock
  up_0_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_217.ConvResblock
  up_0_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_224.ConvResblock
  up_0_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_231.ConvResblock
  output : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ImageUnembedding

And from there, you need to dig into various Python files to find the Conv2d configuration, etc.

BTW, this is just from unzipping decoder.pt to inspect the underlying Python code.
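
(Concretely, a minimal sketch of that: a TorchScript .pt file is a plain zip archive, so the serialized source can be pulled out with the standard library; the output directory name is arbitrary.)

import zipfile

# The serialized "pseudo-source" lives under decoder/code/ inside the zip.
with zipfile.ZipFile("decoder.pt") as zf:
    for name in zf.namelist():
        if name.endswith(".py"):
            print(name)  # e.g. decoder/code/__torch__/.../public_diff_vae.py
    zf.extractall("decoder_unzipped")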

Uploaded weights and "pseudo code" with the correct hparams that match the weights.

The code above looks very much like a conditional UNet with concat conditioning (except that the latents are upsampled 8x with nearest-neighbor interpolation). So a 4x32x32 latent would be upsampled to 4x256x256 and concatenated with the noisy input (3x256x256); from there it looks like a regular UNet.
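
Something like this, as a sketch of the conditioning path (shapes taken from the comment above; names are illustrative):

import torch
import torch.nn.functional as F

latents = torch.randn(1, 4, 32, 32)   # VAE latent
noisy = torch.randn(1, 3, 256, 256)   # noisy image input

# Nearest-neighbor upsample the latent 8x, then concat on channels,
# mirroring torch.upsample_nearest2d + torch.cat in the dump above.
features = F.interpolate(latents, scale_factor=8, mode="nearest")  # 1x4x256x256
unet_in = torch.cat([noisy, features], dim=1)                      # 1x7x256x256
# unet_in then goes through embed_image and the usual down/mid/up stack.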

@mrsteyk's code worked for me after some minor edits 👍


Yeah, I realised after I went to sleep that I'd messed up the skip connections. The up stages originally didn't have 4 non-resizing ConvResblocks.
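
(i.e., matching the ModuleDict above, each up stage is four skip-consuming ConvResblocks followed by an upsample. A structural sketch only; the block internals and the concat-with-the-skip detail are assumptions, not read from the checkpoint:)

import torch
import torch.nn as nn

class ConvResblock(nn.Module):
    # Stand-in for the checkpoint's ConvResblock; the real internals
    # (norm, activation, timestep conditioning) must be read from the dump.
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, 3, padding=1)
        self.skip = nn.Conv2d(c_in, c_out, 1) if c_in != c_out else nn.Identity()

    def forward(self, x, t_emb):
        return self.skip(x) + self.conv(x)

class UpStage(nn.Module):
    # Four non-resizing ConvResblocks, each consuming one skip connection,
    # as in up_k_conv_0..3 above (the upsample block would follow).
    def __init__(self, c, c_skip):
        super().__init__()
        self.convs = nn.ModuleList(ConvResblock(c + c_skip, c) for _ in range(4))

    def forward(self, x, skips, t_emb):
        for conv, skip in zip(self.convs, skips):
            x = conv(torch.cat([x, skip], dim=1), t_emb)
        return x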

commented

Thanks for this commit. Did you test tiled_decode? Or is it not possible for this model?
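
(For context, "tiled decode" here means splitting the latent spatially, decoding each tile, and stitching the outputs. A naive sketch without seam blending, where decode_fn stands in for a call to the consistency decoder:)

import torch

def tiled_decode(decode_fn, latents, tile=32):
    # Decode tile x tile latent chunks independently and stitch the
    # 8x-upsampled outputs back together. A real implementation would
    # overlap tiles and blend seams, as diffusers does for VAE tiling.
    _, _, h, w = latents.shape
    rows = []
    for i in range(0, h, tile):
        row = [decode_fn(latents[:, :, i:i + tile, j:j + tile])
               for j in range(0, w, tile)]
        rows.append(torch.cat(row, dim=3))
    return torch.cat(rows, dim=2)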