johnowhitaker / aiaiart

Course content and resources for the AIAIART course.



Incorrect annotation of shapes in Unet in lesson #7?

vedantroy opened this issue.

Hi! It's me again.

I'm creating an annotated version of the UNet in lesson #7 (diffusion models). I'm adding more comments and assertions for the shapes of all inputs, outputs, weights, and intermediate steps.

While doing this, I noticed what might be a mistake in some of the comments.

Here's the code that runs the UNet on dummy data (from the lesson):

```python
import torch

# A dummy batch of 10 3-channel 32px images
x = torch.randn(10, 3, 32, 32)

# 't' - what timestep are we on (shape [1], shared by all 10 images)
t = torch.tensor([50], dtype=torch.long)

# Define the unet model (UNet is the class defined in lesson #7)
unet = UNet()

# The forward pass (takes both x and t)
model_output = unet(x, t)
```

Inside the UNet, this is the forward pass:

```python
def forward(self, x: torch.Tensor, t: torch.Tensor):
    """
    * `x` has shape `[batch_size, in_channels, height, width]`
    * `t` has shape `[batch_size]`
    """

    # Get time-step embeddings
    t = self.time_emb(t)
```

The docstring says that `t` has shape `[batch_size]`, but in the demo above `t` has shape `[1]`, which is what we'd expect from how the test code constructs it.

Specifically, this assertion fails:

```python
batch_size = x.shape[0]
print(t.shape)  # torch.Size([1]) in the demo above
assert t.shape[0] == batch_size  # fails: 1 != 10
```

I'm not sure exactly what's going on here. My hypothesis is as follows: The UNet is being trained on a batch of images. Each image in the batch should be accompanied by its own time step number. However, it looks like only a single time-step is being passed into the UNet.

Somewhere along the way, this single time step is being accidentally broadcast by PyTorch to match the batch dimension, so the same time step ends up being used for every image.
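The mechanism suspected here is PyTorch's broadcasting, which follows the NumPy rules: shapes are aligned from the right, and a dimension of size 1 is stretched to match the other operand. The sketch below uses illustrative sizes (64 channels is an assumption, not a value taken from the lesson) to show that a tensor with batch dimension 1 combines silently with a batch of 10, giving every image the same values, whereas any other mismatched batch size raises an error.

```python
import torch

# Illustrative sizes: 10 images, but an embedding computed from a single timestep
h = torch.zeros(10, 64, 32, 32)   # feature map: [batch, channels, H, W]
emb = torch.randn(1, 64, 1, 1)    # per-channel offset with batch dimension 1

# Shapes are compared from the right; size-1 dimensions are stretched
print(torch.broadcast_shapes(h.shape, emb.shape))  # torch.Size([10, 64, 32, 32])

out = h + emb
# Every image in the batch received the identical offset
print(torch.equal(out[0], out[9]))  # True

# A batch dimension that is neither 1 nor 10 cannot be broadcast
mismatched = torch.randn(3, 64, 1, 1)
try:
    h + mismatched
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True
print(broadcast_failed)  # True
```

This is why the missing per-image timesteps never surface as an error: a batch dimension of exactly 1 is always a legal broadcast.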

Does that sound correct to you?

I think the accidental broadcast is happening here (the code is messy because I added my own annotations, but it's inside `ResidualBlock`):

```python
batch_size = 1  # hard-coded to match the demo's single timestep

# First convolution layer
h = self.conv1(self.act1(self.norm1(x)))

time_emb = self.time_emb(t)
assert t.shape == (batch_size, self.time_channels)
assert time_emb.shape == (batch_size, self.out_channels)
time_emb = time_emb[:, :, None, None]
assert time_emb.shape == (batch_size, self.out_channels, 1, 1)
# This looks like:
# [ [[a]], [[b]], [[c]], [[d]], [[e]], [[f]] ]
# when self.out_channels = 6
```
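A minimal, self-contained sketch of the step annotated above, assuming that the embedding arriving at the block has shape `[batch_size or 1, time_channels]` and that `self.time_emb` is a single `nn.Linear(time_channels, out_channels)`. The class name `TimeInjection` and the sizes are hypothetical, and the lesson's `ResidualBlock` also contains normalisation, activations and convolutions that are omitted here. It shows that a shared embedding and a per-image embedding both produce an output of shape `[10, out_channels, 32, 32]`, which is why the mismatch never raises an error.

```python
import torch
import torch.nn as nn


class TimeInjection(nn.Module):
    """Hypothetical stand-in for the time-embedding step of ResidualBlock."""

    def __init__(self, time_channels: int, out_channels: int):
        super().__init__()
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, h: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # h: [batch_size, out_channels, height, width] (output of conv1)
        # t: [batch_size or 1, time_channels] (already-embedded timestep)
        time_emb = self.time_emb(t)            # [batch_size or 1, out_channels]
        time_emb = time_emb[:, :, None, None]  # [batch_size or 1, out_channels, 1, 1]
        # Broadcasts over height and width, and over the batch if it is 1
        return h + time_emb


block = TimeInjection(time_channels=16, out_channels=6)
h = torch.randn(10, 6, 32, 32)

shared = block(h, torch.randn(1, 16))      # one embedding for the whole batch
per_image = block(h, torch.randn(10, 16))  # one embedding per image
print(shared.shape, per_image.shape)  # torch.Size([10, 6, 32, 32]) torch.Size([10, 6, 32, 32])
```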

In general, you'd have a separate random `t` for each image (so `t` would have shape `[batch_size]`). But for the demo and during sampling, the whole batch uses the same `t`, so it's convenient to also accept a single value. An alternative would be to enforce the right shape (as you're doing with the assert) and tweak the sampling code to pass in `t` as a tensor of shape `[batch_size]` instead of `[1]`.
During training, `t` has the shape described:

```python
t = torch.randint(0, n_steps, (batch_size,), dtype=torch.long).cuda()
```
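A sketch of the alternative suggested above, in which both call sites pass one timestep per image so that the strict assertion holds. `n_steps = 1000` and the sampling step `50` are assumed values, and `.cuda()` is dropped so the snippet runs on CPU.

```python
import torch

n_steps = 1000  # assumed number of diffusion steps (the lesson defines its own)
batch_size = 10
x = torch.randn(batch_size, 3, 32, 32)  # dummy batch, as in the demo

# Training: an independent random timestep for each image
t_train = torch.randint(0, n_steps, (batch_size,), dtype=torch.long)

# Sampling: every image is at the same step, but we still pass one value per image
step = 50
t_sample = torch.full((batch_size,), step, dtype=torch.long)

# With both call sites producing shape [batch_size], the strict check holds
for t in (t_train, t_sample):
    assert t.shape[0] == x.shape[0]

print(t_train.shape, t_sample.shape)  # torch.Size([10]) torch.Size([10])
print(t_sample[:3])                   # tensor([50, 50, 50])
```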

(I could be mistaken on this, will take a look at the code in more depth when I have a bit more time)

Got it! That makes sense. You allow the network to accept a single value of `t`, which assumes that all images in the batch are at the same timestep. That makes it more convenient to use. Sounds good!

No need to look through the code; I don't want to waste your time!
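To keep accepting a single `t` while making the "same timestep for the whole batch" assumption explicit rather than relying on implicit broadcasting, one option is to expand `t` to shape `[batch_size]` at the top of `UNet.forward`. The helper `as_batch_timesteps` below is hypothetical (not part of the lesson's code); it is a sketch of that idea.

```python
import torch


def as_batch_timesteps(t, batch_size: int) -> torch.Tensor:
    """Hypothetical helper: return timesteps with shape [batch_size].

    Accepts a Python int, a 0-d tensor, a shape-[1] tensor (one step shared
    by the whole batch) or a shape-[batch_size] tensor (one step per image).
    """
    t = torch.as_tensor(t, dtype=torch.long)
    if t.dim() == 0:
        t = t.reshape(1)
    if t.shape == (1,):
        # Explicitly repeat the shared step for every image (a view, no copy)
        return t.expand(batch_size)
    if t.shape != (batch_size,):
        raise ValueError(f"expected t of shape [1] or [{batch_size}], got {list(t.shape)}")
    return t


print(as_batch_timesteps(torch.tensor([50]), 10).shape)  # torch.Size([10])
print(as_batch_timesteps(50, 10)[:3])                    # tensor([50, 50, 50])
```

Calling `t = as_batch_timesteps(t, x.shape[0])` as the first line of `forward` would let the `[batch_size]` assertion pass for both the demo input and the training input, while a genuinely wrong shape still fails loudly.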