Open-source model definitions
patrickvonplaten opened this issue · comments
Hey @gabgoh,
Super cool that you're open-sourcing the consistency decoder of Dalle-3 with a MIT license ❤️
Any chance you can also add the model definitions of the torch.jit binary? Otherwise it'll be quite difficult to port the model to other libraries.
I can try to get a more human readable version of this pushed, but does decoder_consistency.ckpt.code
work for the time being? The complete model definition is in there.
It is not quite as readable, as you need to plumb through various wrappers; for example, at a high level you get:
class ConvUNetVAE(Module):
__parameters__ = []
__buffers__ = []
training : bool
_is_full_backward_hook : Optional[bool]
blocks : __torch__.torch.nn.modules.container.ModuleDict
def forward(self: __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ConvUNetVAE,
x: Tensor,
t: Tensor,
features: Tensor) -> Tensor:
blocks = self.blocks
output = blocks.output
blocks0 = self.blocks
up_0_conv_3 = blocks0.up_0_conv_3
blocks1 = self.blocks
up_0_conv_2 = blocks1.up_0_conv_2
blocks2 = self.blocks
up_0_conv_1 = blocks2.up_0_conv_1
blocks3 = self.blocks
up_0_conv_0 = blocks3.up_0_conv_0
blocks4 = self.blocks
up_1_upsamp = blocks4.up_1_upsamp
blocks5 = self.blocks
up_1_conv_3 = blocks5.up_1_conv_3
blocks6 = self.blocks
up_1_conv_2 = blocks6.up_1_conv_2
blocks7 = self.blocks
up_1_conv_1 = blocks7.up_1_conv_1
blocks8 = self.blocks
up_1_conv_0 = blocks8.up_1_conv_0
blocks9 = self.blocks
up_2_upsamp = blocks9.up_2_upsamp
blocks10 = self.blocks
up_2_conv_3 = blocks10.up_2_conv_3
blocks11 = self.blocks
up_2_conv_2 = blocks11.up_2_conv_2
blocks12 = self.blocks
up_2_conv_1 = blocks12.up_2_conv_1
blocks13 = self.blocks
up_2_conv_0 = blocks13.up_2_conv_0
blocks14 = self.blocks
up_3_upsamp = blocks14.up_3_upsamp
blocks15 = self.blocks
up_3_conv_3 = blocks15.up_3_conv_3
blocks16 = self.blocks
up_3_conv_2 = blocks16.up_3_conv_2
blocks17 = self.blocks
up_3_conv_1 = blocks17.up_3_conv_1
blocks18 = self.blocks
up_3_conv_0 = blocks18.up_3_conv_0
blocks19 = self.blocks
mid_conv_1 = blocks19.mid_conv_1
blocks20 = self.blocks
mid_conv_0 = blocks20.mid_conv_0
blocks21 = self.blocks
down_3_conv_2 = blocks21.down_3_conv_2
blocks22 = self.blocks
down_3_conv_1 = blocks22.down_3_conv_1
blocks23 = self.blocks
down_3_conv_0 = blocks23.down_3_conv_0
blocks24 = self.blocks
down_2_downsamp = blocks24.down_2_downsamp
blocks25 = self.blocks
down_2_conv_2 = blocks25.down_2_conv_2
blocks26 = self.blocks
down_2_conv_1 = blocks26.down_2_conv_1
blocks27 = self.blocks
down_2_conv_0 = blocks27.down_2_conv_0
blocks28 = self.blocks
down_1_downsamp = blocks28.down_1_downsamp
blocks29 = self.blocks
down_1_conv_2 = blocks29.down_1_conv_2
blocks30 = self.blocks
down_1_conv_1 = blocks30.down_1_conv_1
blocks31 = self.blocks
down_1_conv_0 = blocks31.down_1_conv_0
blocks32 = self.blocks
down_0_downsamp = blocks32.down_0_downsamp
blocks33 = self.blocks
down_0_conv_2 = blocks33.down_0_conv_2
blocks34 = self.blocks
down_0_conv_1 = blocks34.down_0_conv_1
blocks35 = self.blocks
down_0_conv_0 = blocks35.down_0_conv_0
blocks36 = self.blocks
embed_image = blocks36.embed_image
blocks37 = self.blocks
embed_time = blocks37.embed_time
input = torch.to(features, torch.device("cuda:0"), 6)
features0 = torch.upsample_nearest2d(input, None, [8., 8.])
x0 = torch.cat([x, features0], 1)
_0 = (embed_time).forward(t, )
_1 = (embed_image).forward(x0, )
_2 = (down_0_conv_0).forward(_1, _0, )
_3 = (down_0_conv_1).forward(_2, _0, )
_4 = (down_0_conv_2).forward(_3, _0, )
_5 = (down_0_downsamp).forward(_4, _0, )
_6 = (down_1_conv_0).forward(_5, _0, )
_7 = (down_1_conv_1).forward(_6, _0, )
_8 = (down_1_conv_2).forward(_7, _0, )
_9 = (down_1_downsamp).forward(_8, _0, )
_10 = (down_2_conv_0).forward(_9, _0, )
_11 = (down_2_conv_1).forward(_10, _0, )
_12 = (down_2_conv_2).forward(_11, _0, )
_13 = (down_2_downsamp).forward(_12, _0, )
_14 = (down_3_conv_0).forward(_13, _0, )
_15 = (down_3_conv_1).forward(_14, _0, )
_16 = (down_3_conv_2).forward(_15, _0, )
_17 = (mid_conv_1).forward((mid_conv_0).forward(_16, _0, ), _0, )
_18 = (up_3_conv_0).forward(_17, _16, _0, )
_19 = (up_3_conv_1).forward(_18, _15, _0, )
_20 = (up_3_conv_2).forward(_19, _14, _0, )
_21 = (up_3_conv_3).forward(_20, _13, _0, )
_22 = (up_2_conv_0).forward((up_3_upsamp).forward(_21, _0, ), _12, _0, )
_23 = (up_2_conv_1).forward(_22, _11, _0, )
_24 = (up_2_conv_2).forward(_23, _10, _0, )
_25 = (up_2_conv_3).forward(_24, _9, _0, )
_26 = (up_1_conv_0).forward((up_2_upsamp).forward(_25, _0, ), _8, _0, )
_27 = (up_1_conv_1).forward(_26, _7, _0, )
_28 = (up_1_conv_2).forward(_27, _6, _0, )
_29 = (up_1_conv_3).forward(_28, _5, _0, )
_30 = (up_0_conv_0).forward((up_1_upsamp).forward(_29, _0, ), _4, _0, )
_31 = (up_0_conv_1).forward(_30, _3, _0, )
_32 = (up_0_conv_2).forward(_31, _2, _0, )
_33 = (up_0_conv_3).forward(_32, _1, _0, )
return (output).forward(_33, )
and how the blocks
are defined has to be found here (in the decoder/code/__torch__/torch/nn/modules/container.py
):
class ModuleDict(Module):
__parameters__ = []
__buffers__ = []
training : bool
_is_full_backward_hook : Optional[bool]
embed_image : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ImageEmbedding
embed_time : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.TimestepEmbedding
down_0_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ConvResblock
down_0_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_9.ConvResblock
down_0_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_15.ConvResblock
down_0_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_21.ConvResblock
down_1_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_28.ConvResblock
down_1_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_34.ConvResblock
down_1_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_40.ConvResblock
down_1_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_46.ConvResblock
down_2_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_53.ConvResblock
down_2_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_59.ConvResblock
down_2_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_65.ConvResblock
down_2_downsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_71.ConvResblock
down_3_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_77.ConvResblock
down_3_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_83.ConvResblock
down_3_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_89.ConvResblock
mid_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_95.ConvResblock
mid_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_101.ConvResblock
up_3_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_108.ConvResblock
up_3_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_115.ConvResblock
up_3_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_122.ConvResblock
up_3_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_129.ConvResblock
up_3_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_135.ConvResblock
up_2_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_142.ConvResblock
up_2_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_149.ConvResblock
up_2_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_156.ConvResblock
up_2_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_163.ConvResblock
up_2_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_169.ConvResblock
up_1_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_176.ConvResblock
up_1_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_183.ConvResblock
up_1_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_190.ConvResblock
up_1_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_197.ConvResblock
up_1_upsamp : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_203.ConvResblock
up_0_conv_0 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_210.ConvResblock
up_0_conv_1 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_217.ConvResblock
up_0_conv_2 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_224.ConvResblock
up_0_conv_3 : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.___torch_mangle_231.ConvResblock
output : __torch__.dalle_runner_api.model_infra.modules.public_diff_vae.ImageUnembedding
And from there, you need to dig into various Python files to find the Conv2d configuration, etc.
BTW, this comes from simply unzipping the decoder.pt
to inspect the underlying Python code.
Uploaded the weights and "pseudo code" with the correct hparams that produce these weights.
The code above looks very much like a conditional UNet with concat conditioning (except that latents are upscaled by 8x using nearest neighbor upsampling). So for a latent of 4x32x32, it would be upsampled to 4x256x256 and then concatenated with the noisy input (3x256x256), then it looks like a regular UNet.
@mrsteyk's code worked for me after some minor edits 👍
Yeah, I realised I had messed up the skip connections when I went to sleep. The up stages originally didn't have 4 non-resizing ConvResblocks.
Thanks for this commit. Did you test tiled_decode
? Or is it not possible for this model?