Gradient calculation for recurrent operators is wrong
matteosal opened this issue
This script creates an RNN operator and computes its input gradient 5 times, for sequence lengths 1, 2, 3, 4, 5. It then prints each gradient element at a fixed sequence position for all the computed sequence lengths:
import mxnet as mx
from mxnet import autograd
import numpy as np

batch_size = 1
data_len = 5
input_size = 2
output_size = 3

param_shapes = {
    'wx': [output_size, input_size],
    'ws': [output_size, output_size],
    'bx': [output_size],
    'bs': [output_size]
}
fused_param_len = np.sum(
    [np.prod(v) for v in param_shapes.values()]
)
shapes = {
    'data': [data_len, batch_size, input_size],
    'par': [fused_param_len],
    's0': [1, batch_size, output_size]
}

sym = mx.symbol.RNN(
    *[mx.symbol.Variable(name) for name in shapes.keys()],
    state_size=output_size,
    num_layers=1,
    mode='rnn_tanh'
)
op = mx.ndarray.CachedOp(sym)
args = [mx.np.random.uniform(size=shape, ctx=mx.cpu()) for shape in shapes.values()]

def get_grad(seq_len):
    input_data = args[0][:seq_len]
    with autograd.record(train_mode=True):
        input_data.attach_grad()
        output = op(input_data, args[1], args[2], default_ctx=mx.cpu())
    # head gradients must match the output shape, [seq_len, batch_size, output_size]
    autograd.backward(output, head_grads=mx.np.ones([seq_len, batch_size, output_size], ctx=mx.cpu()))
    return input_data.grad

results = []
for i in range(1, 6):
    print('**************')
    print('Input gradient for sequence length = ' + str(i) + '\n')
    results.append(get_grad(i))
    print(results[-1])
    print('\n')

for i in range(4):
    print('++++++++++++++')
    print('Element #' + str(i) + ' of all input gradients')
    for j in range(i, 5):
        print('sequence length: ' + str(j+1) + ': ' + str(results[j][i]))
    print('\n')
The output is:
**************
Input gradient for sequence length = 1
[[[0.14385478 0.05408207]]]
**************
Input gradient for sequence length = 2
[[[0.14385478 0.05408207]]
[[0.01706791 0.00660894]]]
**************
Input gradient for sequence length = 3
[[[0.14385478 0.05408207]]
[[0.01706791 0.00660894]]
[[0.0178871 0.00672178]]]
**************
Input gradient for sequence length = 4
[[[0.14385478 0.05408207]]
[[0.01706791 0.00660894]]
[[0.0178871 0.00672178]]
[[0.01958952 0.00729937]]]
**************
Input gradient for sequence length = 5
[[[0.14385478 0.05408207]]
[[0.01706791 0.00660894]]
[[0.0178871 0.00672178]]
[[0.01958952 0.00729937]]
[[0.02612576 0.00999804]]]
++++++++++++++
Element #0 of all input gradients
sequence length: 1: [[0.14385478 0.05408207]]
sequence length: 2: [[0.14385478 0.05408207]]
sequence length: 3: [[0.14385478 0.05408207]]
sequence length: 4: [[0.14385478 0.05408207]]
sequence length: 5: [[0.14385478 0.05408207]]
++++++++++++++
Element #1 of all input gradients
sequence length: 2: [[0.01706791 0.00660894]]
sequence length: 3: [[0.01706791 0.00660894]]
sequence length: 4: [[0.01706791 0.00660894]]
sequence length: 5: [[0.01706791 0.00660894]]
++++++++++++++
Element #2 of all input gradients
sequence length: 3: [[0.0178871 0.00672178]]
sequence length: 4: [[0.0178871 0.00672178]]
sequence length: 5: [[0.0178871 0.00672178]]
++++++++++++++
Element #3 of all input gradients
sequence length: 4: [[0.01958952 0.00729937]]
sequence length: 5: [[0.01958952 0.00729937]]
In the last 4 sections (those starting with ++++++++++++++), it can be seen that gradient elements at the same sequence position are equal across all 5 gradient computations with sequence lengths 1, 2, 3, 4, 5, whenever the gradient is long enough to contain that element (e.g. the gradient for sequence length 2 obviously cannot have element #3). This means that RNN behaves as if the presence of later elements in the sequence does not affect the gradient of earlier elements.
But this is clearly wrong: by the nature of recurrent computation, earlier elements in the sequence DO affect later ones, so gradient elements at the same sequence position should change when the sequence length changes. When a longer input sequence has an additional element, the gradients of all earlier elements should receive an extra contribution backpropagated through that new element, changing their values.
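This can be checked numerically with a minimal scalar tanh RNN in plain NumPy (the weights below are arbitrary illustrative values, not taken from the script above): the finite-difference gradient of the loss with respect to the first input element changes every time the sequence grows.

```python
import numpy as np

def rnn_loss(x, wx=0.5, ws=0.7, b=0.1):
    # scalar tanh RNN: s_t = tanh(wx * x_t + ws * s_{t-1} + b);
    # loss = sum of all states (equivalent to all-ones head gradients)
    s, loss = 0.0, 0.0
    for xt in x:
        s = np.tanh(wx * xt + ws * s + b)
        loss += s
    return loss

def grad_x0(x, eps=1e-6):
    # central finite difference of the loss w.r.t. the first input element
    xp, xm = x.copy(), x.copy()
    xp[0] += eps
    xm[0] -= eps
    return (rnn_loss(xp) - rnn_loss(xm)) / (2 * eps)

x = np.array([0.3, -0.2, 0.8])
for n in (1, 2, 3):
    print('sequence length', n, '-> grad of element #0:', grad_x0(x[:n]))
```

Each additional time step adds a term backpropagated through the recurrence, so the three printed gradients all differ, unlike in the MXNet output above.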
This is not a direct comparison against a manual computation of the gradient, but pointing out this behavior is enough to conclude that the gradients computed by this op are wrong. I should also point out that this happens for all other settings of the operator's mode parameter, not only mode='rnn_tanh'.
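For reference, the manual computation is easy to sketch in plain NumPy for the rnn_tanh case. The function below (my own sketch, not an MXNet API) runs backpropagation through time for a single-layer tanh RNN with all-ones head gradients, mirroring the script above; the weights are random placeholders. Printing the gradient of element #0 for growing sequence lengths shows the behavior the fused operator should have.

```python
import numpy as np

def rnn_input_grad(x, Wx, Ws, b, s0):
    """Input gradient of a tanh RNN where loss = sum of all hidden states."""
    T = x.shape[0]
    s = np.zeros((T, Ws.shape[0]))
    prev = s0
    for t in range(T):  # forward: s_t = tanh(Wx x_t + Ws s_{t-1} + b)
        s[t] = np.tanh(Wx @ x[t] + Ws @ prev + b)
        prev = s[t]
    gx = np.zeros_like(x)
    da_next = np.zeros(Ws.shape[0])
    for t in reversed(range(T)):  # backward pass (BPTT)
        ds = np.ones(Ws.shape[0]) + Ws.T @ da_next  # head grad + recurrent term
        da = ds * (1.0 - s[t] ** 2)                 # through tanh
        gx[t] = Wx.T @ da
        da_next = da
    return gx

rng = np.random.default_rng(0)
Wx = rng.uniform(size=(3, 2))
Ws = rng.uniform(size=(3, 3))
b = rng.uniform(size=3)
s0 = np.zeros(3)
x = rng.uniform(size=(5, 2))

for n in (1, 2, 3, 4, 5):
    print('sequence length', n, '-> grad of element #0:',
          rnn_input_grad(x[:n], Wx, Ws, b, s0)[0])
```

Here element #0 of the gradient changes with every sequence length, because the recurrent term Ws.T @ da_next feeds the new time steps' contributions back into the earlier positions.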
@bgawrych @barry-jin @szha I think this deserves attention
Updating this to report that the operator behaves correctly with oneDNN.