Explanation of img_to_tensor in lesson 7?
vedantroy opened this issue · comments
Hi! Thanks for making these notebooks open-source. I'm trying to rewrite your code from this notebook in jax for practice. Happy to submit a PR with the end result if you'd find that useful.
I was wondering if you could explain the following method from lesson 7:
def img_to_tensor(im):
return torch.tensor(np.array(im.convert('RGB'))/255).permute(2, 0, 1).unsqueeze(0) * 2 - 1
I'm a bit confused about why we need the permute.
I'm also a bit confused since it seems like we can omit the unsqueeze and the code still works fine.
I would really appreciate it if you had a moment to give a quick explanation.
Thanks!
Sure thing.
PIL (and many other libraries) represent an image as an array of values between 0 and 255. Each pixel gets three values - one for Red, one for Green and one for Blue. The 'shape' of this array is usually [HEIGHT, WIDTH, 3] - the 3 is the three color channels, so we might abstract this and say [H, W, C] (or [W, H, C] depending on the library).
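To see this concretely, here's a minimal check of the PIL layout, assuming PIL and numpy are installed (the image is synthetic, just for illustration):

```python
import numpy as np
from PIL import Image

im = Image.new('RGB', (64, 48))  # PIL size is (width, height)
arr = np.array(im)
print(arr.shape)   # (48, 64, 3) -> [H, W, C]
print(arr.dtype)   # uint8, values 0-255
```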
In PyTorch we typically represent images slightly differently:
- The channels usually come before the other dimensions. Hence permute(2, 0, 1), which takes the last dimension ('C') and moves it before the other two.
- Many operations expect a BATCH of images rather than a single image. For example, a batch of 8 images each 32px square would have shape [8, 3, 32, 32]. To keep things consistent, when dealing with a single image we often treat it like a batch of size 1. So given a tensor of shape [3, 32, 32] we'd add an extra dimension with unsqueeze(0) to give a new shape [1, 3, 32, 32]. This way we can use the same code for this as we would for a larger batch. In many cases you can get away without it, but it's nice to be consistent.
- Finally, since we don't need to restrict ourselves to 8-bit integers, we typically represent the R, G and B values as floats between 0 and 1 OR between -1 and 1 (the latter is preferred for certain kinds of tasks). Hence in this case we take the original values, divide by 255 (to make them 0-1) and then * 2 - 1 (to make them -1 to 1).
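Putting those three steps together, here's a quick sketch tracing the shapes and value range through img_to_tensor, assuming torch, numpy and PIL are installed (the red test image is made up for illustration):

```python
import numpy as np
import torch
from PIL import Image

def img_to_tensor(im):
    # HWC uint8 [0, 255] -> float [0, 1] -> CHW -> add batch dim -> scale to [-1, 1]
    return torch.tensor(np.array(im.convert('RGB')) / 255).permute(2, 0, 1).unsqueeze(0) * 2 - 1

im = Image.new('RGB', (32, 32), color=(255, 0, 0))  # synthetic 32x32 solid-red image
x = img_to_tensor(im)
print(x.shape)                          # torch.Size([1, 3, 32, 32])
print(x.min().item(), x.max().item())   # -1.0 1.0 (red channel at 1, others at -1)
```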
I'd be interested to see the JAX version - I've also been playing around with it a little. With JAX you might not worry too much about adding the batch dimension (.unsqueeze), since there you often write a function that operates on a single example and then vectorize it with vmap().
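For reference, a rough sketch of that vmap pattern, assuming jax is installed (the function and array here are illustrative, not from the notebook):

```python
import jax
import jax.numpy as jnp

def normalize_one(img):
    # operates on a single [C, H, W] image, scaling [0, 1] floats to [-1, 1]
    return img * 2 - 1

# vmap lifts the single-example function to one that handles a leading batch dim
normalize_batch = jax.vmap(normalize_one)

batch = jnp.full((8, 3, 32, 32), 0.5)
out = normalize_batch(batch)
print(out.shape)  # (8, 3, 32, 32)
```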
Thanks for the detailed explanation! This is super helpful.
One last question here: Do you have to do the permute(2, 0, 1)?
I.e., from Googling around, putting the channel first is useful if you are using transforms from the torchvision library, which I don't see any of in the code.
Are you doing the permute for convention, or is there a deeper reason?
It's a convention I follow throughout the course since most lessons use torchvision transforms and/or other operations which assume the channels come first. Even in lesson 7 there are things which expect channels first - for example, the convolutional layers in the UNet (torch.nn.Conv2d).
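A quick way to see that Conv2d really does assume channels-first, assuming torch is installed (the layer sizes are arbitrary):

```python
import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

x = torch.randn(1, 3, 32, 32)  # [batch, channels, height, width] - channels first
print(conv(x).shape)           # torch.Size([1, 16, 32, 32])

# a channels-last tensor is misread as having 32 input channels instead of 3
x_wrong = torch.randn(1, 32, 32, 3)
try:
    conv(x_wrong)
except RuntimeError as e:
    print('fails:', e)
```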