apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more

Home Page: https://mxnet.apache.org

Gradient calculation for recurrent operators is wrong

matteosal opened this issue

This script creates an RNN operator and computes its input gradient 5 times, for sequence lengths 1, 2, 3, 4 and 5. It then prints each gradient element at a fixed sequence position across all the computed sequence lengths:

import mxnet as mx
from mxnet import autograd
import numpy as np

batch_size = 1
data_len = 5
input_size = 2
output_size = 3

# Shapes of the four parameter blocks (input weights, recurrent weights,
# input bias, recurrent bias) that the RNN op expects flattened into a
# single fused parameter vector.
param_shapes = {
	'wx': [output_size, input_size], 
	'ws': [output_size, output_size], 
	'bx': [output_size],
	'bs': [output_size]
}
fused_param_len = np.sum(
	[np.prod(v) for v in param_shapes.values()]
)

shapes = {
	'data': [data_len, batch_size, input_size], 
	'par': [fused_param_len], 
	's0': [1, batch_size, output_size]
}

sym = mx.symbol.RNN(
	*[mx.symbol.Variable(name) for name in shapes.keys()],
	state_size=output_size,
	num_layers=1,
	mode='rnn_tanh'
)
op = mx.ndarray.CachedOp(sym)

args = [mx.np.random.uniform(size=shape, ctx=mx.cpu()) for shape in shapes.values()]

# Run the RNN on the first seq_len steps of the data and return the input
# gradient of the summed output (head gradient of ones).
def get_grad(seq_len):
	input_data = args[0][:seq_len]
	with autograd.record(train_mode=True):
		input_data.attach_grad()
		output = op(input_data, args[1], args[2], default_ctx=mx.cpu())
	autograd.backward(output, head_grads=mx.np.ones([data_len, batch_size, output_size], ctx=mx.cpu()))
	return input_data.grad

results = []
for i in range(1, 6):
	print('**************')
	print('Input gradient for sequence length = ' + str(i) + '\n')
	results.append(get_grad(i))
	print(results[-1])
	print('\n')

for i in range(4):
	print('++++++++++++++')
	print('Element #' + str(i) + ' of all input gradients')
	for j in range(i, 5):
		print('sequence length: ' + str(j+1) + ': ' + str(results[j][i]))
	# [print('sequence length: ' + str(i+1) + ': ' + str(grad[i])) for grad in results[i:]]
	print('\n')

The output is:

**************
Input gradient for sequence length = 1

[[[0.14385478 0.05408207]]]


**************
Input gradient for sequence length = 2

[[[0.14385478 0.05408207]]
 [[0.01706791 0.00660894]]]


**************
Input gradient for sequence length = 3

[[[0.14385478 0.05408207]]
 [[0.01706791 0.00660894]]
 [[0.0178871  0.00672178]]]


**************
Input gradient for sequence length = 4

[[[0.14385478 0.05408207]]
 [[0.01706791 0.00660894]]
 [[0.0178871  0.00672178]]
 [[0.01958952 0.00729937]]]


**************
Input gradient for sequence length = 5

[[[0.14385478 0.05408207]]
 [[0.01706791 0.00660894]]
 [[0.0178871  0.00672178]]
 [[0.01958952 0.00729937]]
 [[0.02612576 0.00999804]]]


++++++++++++++
Element #0 of all input gradients
sequence length: 1: [[0.14385478 0.05408207]]
sequence length: 2: [[0.14385478 0.05408207]]
sequence length: 3: [[0.14385478 0.05408207]]
sequence length: 4: [[0.14385478 0.05408207]]
sequence length: 5: [[0.14385478 0.05408207]]


++++++++++++++
Element #1 of all input gradients
sequence length: 2: [[0.01706791 0.00660894]]
sequence length: 3: [[0.01706791 0.00660894]]
sequence length: 4: [[0.01706791 0.00660894]]
sequence length: 5: [[0.01706791 0.00660894]]


++++++++++++++
Element #2 of all input gradients
sequence length: 3: [[0.0178871  0.00672178]]
sequence length: 4: [[0.0178871  0.00672178]]
sequence length: 5: [[0.0178871  0.00672178]]


++++++++++++++
Element #3 of all input gradients
sequence length: 4: [[0.01958952 0.00729937]]
sequence length: 5: [[0.01958952 0.00729937]]

The last 4 sections (those starting with ++++++++++++++) show that gradient elements at the same sequence position are equal across all 5 gradient computations with sequence lengths 1 to 5 (whenever the gradient is long enough to contain that element; the gradient for sequence length 2 obviously cannot have element #3, for example). In other words, the RNN op behaves as if the presence of later elements in the sequence did not affect the gradient of earlier elements.
But this is clearly wrong: by the nature of recurrent computations, earlier elements in the sequence DO affect later ones, so gradient elements at the same sequence position should change when the sequence length changes. When the input sequence is extended by one element, the gradient of every earlier element should pick up an additional contribution backpropagated from the new element, changing its value (see the numpy sketch below).
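
To make the argument concrete, here is a minimal sketch in plain numpy, independent of MXNet and of the script above (all sizes and values are made up for illustration): a single-layer tanh RNN in which the gradient of the summed output with respect to the first input element is computed by central finite differences for a few sequence lengths.

import numpy as np

rng = np.random.default_rng(0)
input_size, state_size = 2, 3
Wx = rng.uniform(size=(state_size, input_size))   # input-to-hidden weights
Ws = rng.uniform(size=(state_size, state_size))   # hidden-to-hidden weights
b = rng.uniform(size=state_size)                  # bias
x = rng.uniform(size=(5, input_size))             # a length-5 input sequence

def summed_output(x_seq):
	# Run the tanh RNN over x_seq and return the sum of all outputs: a scalar
	# loss equivalent to backpropagating a head gradient of ones.
	h = np.zeros(state_size)
	total = 0.0
	for x_t in x_seq:
		h = np.tanh(Wx @ x_t + Ws @ h + b)
		total += h.sum()
	return total

def grad_x0(seq_len, eps=1e-6):
	# Central finite-difference gradient of the loss w.r.t. x[0] for a given length.
	g = np.zeros(input_size)
	for k in range(input_size):
		plus, minus = x[:seq_len].copy(), x[:seq_len].copy()
		plus[0, k] += eps
		minus[0, k] -= eps
		g[k] = (summed_output(plus) - summed_output(minus)) / (2 * eps)
	return g

for L in (1, 2, 3):
	print('sequence length', L, '-> d(loss)/d(x[0]) =', grad_x0(L))

Running this prints three different gradients for x[0]: each extra time step adds a new backpropagated contribution, because x[0] feeds into every later hidden state. That is exactly what the MXNet output above fails to show.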

This is not a direct comparison with a manual computation of the gradient, but the behavior above is already enough to conclude that the gradients computed by this op are wrong (a rough sketch of such a direct check is included below). I should also point out that this happens for every other setting of the operator's mode parameter, not only for mode='rnn_tanh'.
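
For reference, a rough sketch of such a direct check, reusing op, args and get_grad from the script above. The helper names loss and numeric_grad are ad-hoc (not part of any MXNet API), the perturbations are done in plain numpy to avoid in-place updates of MXNet arrays, and it assumes mx.np.array accepts the same ctx keyword as the other mx.np calls in the script:

def loss(data_np):
	# Forward pass only: wrap the (possibly perturbed) numpy input and
	# sum the RNN output into a scalar, matching the all-ones head gradient.
	data = mx.np.array(data_np, ctx=mx.cpu())
	return op(data, args[1], args[2], default_ctx=mx.cpu()).asnumpy().sum()

def numeric_grad(data_np, eps=1e-4):
	# Central finite differences over every element of the input sequence.
	g = np.zeros_like(data_np)
	it = np.nditer(data_np, flags=['multi_index'])
	for _ in it:
		idx = it.multi_index
		plus, minus = data_np.copy(), data_np.copy()
		plus[idx] += eps
		minus[idx] -= eps
		g[idx] = (loss(plus) - loss(minus)) / (2 * eps)
	return g

# e.g. compare numeric_grad(args[0][:3].asnumpy()) with get_grad(3).asnumpy()

On a correct implementation the numerical and autograd gradients should agree to a few decimal places.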

@bgawrych @barry-jin @szha I think this deserves attention

Updating this to report that the operator behaves correctly when it runs through oneDNN.