microsoft / nn-Meter

A DNN inference latency prediction toolkit for accurately modeling and predicting the latency on diverse edge devices.

Question on Padding Calculation for Building PyTorch Networks

WingsUpete opened this issue · comments

Hi. I am confused about this small utility function (get_padding):

According to PyTorch:
$$h_o = \lfloor \frac{h_i + 2p - k}{s} + 1 \rfloor,$$
where $h_i, h_o$ denote the input and output heights, $p$ the padding, $k$ the kernel size, and $s$ the stride. Solving for $p$, and taking the target output size to be $h_o = h_i // s$ (as the function's comment suggests), the equation can be rearranged as:
$$p \rightarrow \frac{s(h_o - 1) + k - h_i}{2} = \frac{s h_o - h_i + k - s}{2} = \frac{k - (h_i \% s) - s}{2}.$$

From my understanding, the comment in the function implies that $s h_o$ and $h_i$ should be close to each other ($h_o = h_i // s$) in order to avoid large calculated padding values. I also understand that, because of the floor in the equation, it gets complicated to pin down the exact $p$ that meets the requirement. My question here is about the code implementation of this line in get_padding:

Shouldn't the calculation subtract an extra $s$ (i.e., pad = max(ks - (hw % s) - s, 0))?
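
For reference, here is a minimal stand-in I used for testing. It is only my reconstruction of the utility (not the actual nn-Meter source), but it reproduces the result below:

def get_padding(ks, s, hw):
    # my reading of the current behavior: clamp (ks - hw % s) at zero, then halve
    pad = max(ks - (hw % s), 0)
    return pad // 2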

For confirmation, here is some test code with an example ($k=11, s=2, h=3$):

>>> k = 11
>>> s = 2
>>> h = 3
# test the function 
>>> p0 = get_padding(ks=k, s=s, hw=h)
>>> p0
5
# calculate directly using the inferred equation
>>> p1 = (k - (h % s) - s) / 2
>>> p1
4.0
# for confirmation, calculate h_o with p0 and p1 (expected to be h // s)
>>> (h + 2 * p0 - k) / s + 1
2.0
>>> (h + 2 * p1 - k) / s + 1
1.0
>>> h // s
1

Hi,

Thank you for your thoughtful question! We use this padding method to ensure that the output shape is consistent with what TensorFlow produces in the same situation. In TensorFlow, tf.keras.layers.Conv2D has a "padding" parameter that can be set to 'same'. In PyTorch, however, padding='same' does not support stride values other than 1. Therefore, we wrote this padding code, which produces the same output shape as TensorFlow for the same configuration.
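
For example, trying padding='same' with a stride of 2 in PyTorch raises an error at construction time (a quick illustration; the exact message may vary across versions):

import torch.nn as nn

# padding='same' is rejected for strided convolutions in PyTorch
conv = nn.Conv2d(128, 128, kernel_size=11, stride=2, padding='same')
# ValueError: padding='same' is not supported for strided convolutions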

As for the example you mentioned, in TensorFlow the output spatial size is 2:

import tensorflow as tf
from tensorflow import keras

inputs = keras.Input(shape=[3, 3, 128], batch_size=1)
layer = keras.layers.Conv2D(128, kernel_size=11, strides=2, padding="same")
print(layer(inputs).shape)
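
A corresponding PyTorch check with the padding of 5 computed by get_padding (just an illustrative snippet, not nn-Meter code) gives the same spatial size of 2:

import torch
import torch.nn as nn

x = torch.randn(1, 128, 3, 3)
conv = nn.Conv2d(128, 128, kernel_size=11, stride=2, padding=5)
print(conv(x).shape)  # torch.Size([1, 128, 2, 2])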

Thanks,
Jiahang

Thank you for the reply. The fact that the idea came from TensorFlow is very helpful. For confirmation, I went to check the TensorFlow code and found this implementation. It seems that output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i]) actually corresponds to $\lceil h / s \rceil$, i.e. $h // s + 1$ whenever $h$ is not divisible by $s$, which is different from the docstring in your function:

if s = 2, out_hw = in_hw // 2;

I would recommend changing the documentation to avoid misunderstanding, though it is trivial.
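
Concretely, for the example above ($h = 3$, $s = 2$), the two differ:

import math

h, s = 3, 2
print(math.ceil(h / s))  # 2, the TensorFlow 'same' output size
print(h // s)            # 1, what the docstring currently states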

Again, thanks for the quick reply!

Thanks for your advice! I will refine the doc soon to make it clear. 😃