johnowhitaker / aiaiart

Course content and resources for the AIAIART course.

Explanation of img_to_tensor in lesson 7?

vedantroy opened this issue

Hi! Thanks for making these notebooks open-source. I'm trying to rewrite your code from this notebook in jax for practice. Happy to submit a PR with the end result if you'd find that useful.

I was wondering if you could explain the following method from lesson 7:

import numpy as np
import torch

def img_to_tensor(im):
  # PIL image -> [1, C, H, W] float tensor with values in [-1, 1]
  return torch.tensor(np.array(im.convert('RGB'))/255).permute(2, 0, 1).unsqueeze(0) * 2 - 1

I'm a bit confused about why we need the permute.
I'm also a bit confused since it seems like we can omit the unsqueeze etc. and the code still works fine.

I would really appreciate it if you had a moment to give a quick explanation.
Thanks!

Sure thing.
PIL (and many other libraries) represent an image as an array of values between 0 and 255. Each pixel gets three values - one for Red, one for Green and one for Blue. The 'shape' of this array is usually [HEIGHT, WIDTH, 3] - the 3 is the three color channels, so we might abstract this and say [H, W, C] (or [W, H, C] depending on the library).
In PyTorch we typically represent images slightly differently:

  • The channels usually come before the other dimensions. Hence permute(2, 0, 1), which takes the last dimension ('C') and moves it before the other two.
  • Many operations expect a BATCH of images rather than a single image. For example, a batch of 8 images each 32px square would have shape [8, 3, 32, 32]. To keep things consistent, when dealing with a single image we often treat it like a batch of size 1. So given a tensor of shape [3, 32, 32] we'd add an extra dimension with unsqueeze(0) to give a new shape [1, 3, 32, 32]. This way we can use the same code as we would for a larger batch. In many cases you can get away without it, but it's nice to be consistent.
  • Finally, since we don't need to restrict ourselves to 8-bit integers we typically represent the R, G and B values as floats between 0 and 1 OR between -1 and 1 (the latter is preferred for certain kinds of tasks). Hence in this case we take the original values, divide by 255 (to make it 0-1) and then * 2 - 1 (to make it -1 to 1).
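To make the three steps above concrete, here's a sketch that applies them one at a time to a dummy uint8 array (standing in for the PIL image) and prints the shape at each stage:

```python
import numpy as np
import torch

# Dummy "PIL-style" image: uint8 array of shape [H, W, C] with values 0-255.
im = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)

t = torch.tensor(im / 255)  # floats in [0, 1], still [H, W, C]
t = t.permute(2, 0, 1)      # channels first: [C, H, W]
t = t.unsqueeze(0)          # add a batch dimension: [1, C, H, W]
t = t * 2 - 1               # rescale from [0, 1] to [-1, 1]

print(t.shape)  # torch.Size([1, 3, 32, 32])
```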

I'd be interested to see the JAX version - I've also been playing around with it a little. With JAX you might not worry too much about adding the batch dimension (.unsqueeze), since you often write a function that operates on a single example and then vectorize it with vmap().
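As a rough illustration of that vmap pattern (function and shapes here are just an example, not code from the course): you write the normalization for one channels-last image and let vmap handle the batch axis:

```python
import jax
import jax.numpy as jnp

def normalize_one(im):
    # Operates on a single [H, W, C] image; no batch dim, and no permute
    # if downstream code also uses channels-last (common in JAX).
    return im.astype(jnp.float32) / 255 * 2 - 1

# vmap lifts the single-example function to a batched one: [N, H, W, C].
normalize_batch = jax.vmap(normalize_one)

batch = jnp.zeros((8, 32, 32, 3), dtype=jnp.uint8)
out = normalize_batch(batch)
print(out.shape)  # (8, 32, 32, 3)
```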

Thanks for the detailed explanation! This is super helpful.
One last question here: Do you have to do the permute(2, 0, 1)?

From Googling around, it seems putting the channels first is useful if you're using transforms from the torchvision library, but I don't see any of those in this code.

Are you doing the permute for convention, or is there a deeper reason?

It's a convention I follow throughout the course since most lessons use torchvision transforms and/or other operations which assume the channels come first. Even in lesson 7 there are things which expect channels first - for example, the convolutional layers in the UNet (torch.nn.Conv2d).
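For instance, a minimal check of the Conv2d expectation mentioned above: it reads dimension 1 of its input as the channel count, so only the channels-first layout lines up with in_channels.

```python
import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

x = torch.randn(1, 3, 32, 32)  # [batch, channels, H, W] - what Conv2d expects
y = conv(x)
print(y.shape)  # torch.Size([1, 16, 32, 32])

# A channels-last tensor of shape [1, 32, 32, 3] would fail here:
# Conv2d would read 32 as the channel count and complain it isn't 3.
```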