Unexpected result in LogSumExp gradient using Tangent package in Python
paupereira opened this issue · comments
Problem:
- First implementation:
I'm trying to get Tangent to compute the gradient of a function that contains the following implementation of logsumexp:
import numpy as np
import tangent

def logsumexp(a):
    # a = a.reshape(-1)
    result = 0.0
    largest_in_a = a[0]
    a_shape = len(a)
    # numba is slow when using max or np.max, so re-implementing:
    for i in range(1, a_shape):
        if a[i] > largest_in_a:
            largest_in_a = a[i]
    for i in range(a_shape):
        result += np.exp(a[i] - largest_in_a)
    return np.log(result) + largest_in_a
I call tangent as follows:
x = np.array([1,2,3,4])
grad_logsumexp = tangent.grad(logsumexp)
And get the result
grad_logsumexp(x)
Out[100]: array([0, 0, 0, 0])
While the correct answer is
array([0.0320586 , 0.08714432, 0.23688282, 0.64391426])
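For reference, the gradient of logsumexp is the softmax of its input, so the expected values above can be checked with plain NumPy, independent of Tangent. A minimal sketch (the helper name `softmax_reference` is made up for illustration and is not part of Tangent or NumPy):

```python
import numpy as np

def softmax_reference(a):
    # Hypothetical helper: d/da log(sum(exp(a))) equals softmax(a).
    # Subtracting the max avoids overflow in exp() and leaves the result unchanged.
    a = np.asarray(a, dtype=float)
    shifted = np.exp(a - np.max(a))
    return shifted / np.sum(shifted)

expected_grad = softmax_reference([1, 2, 3, 4])
print(expected_grad)
```

The entries sum to 1, which is a quick sanity check on any logsumexp gradient.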
- Second implementation:
On the other hand, doing this works:
def logsumexp_naive(a):
    return np.log(np.sum(np.exp(a)))
grad_logsumexp_naive = tangent.grad(logsumexp_naive)
grad_logsumexp_naive(x)
Question:
What's going on with the first implementation?
Here's a simpler example of the (perhaps) same behaviour:
import numpy as np
import tangent

x = np.array([1, 2, 3, 4, 5])  # assumed input; the five-element outputs below imply it
def sum_squares(x):
    res = 0
    for a in x:
        res = res + a ** 2
    return res
sum_squares(x)
sum_squares_grad = tangent.grad(sum_squares)
sum_squares_grad(x) # out: array([0, 0, 0, 0, 0])
#------------------------------------------------------------
def sum_squares2(x):
    res = 0
    for a in range(len(x)):
        res = res + x[a] ** 2
    return res
sum_squares2(x)
sum_squares2_grad = tangent.grad(sum_squares2)
sum_squares2_grad(x) # out: array([ 2, 4, 6, 8, 10])
#------------------------------------------------------------
def sum_squares3(x):
    return np.sum(x ** 2)
sum_squares3(x)
sum_squares3_grad = tangent.grad(sum_squares3)
sum_squares3_grad(x) # out: array([ 2, 4, 6, 8, 10])
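One way to confirm which of the three outputs is right, without relying on Tangent at all, is a central-difference check. This is only a sketch: `numerical_grad` and `x_check` are names invented here, and the five-element input is an assumption inferred from the outputs above.

```python
import numpy as np

def numerical_grad(f, x, eps=1e-6):
    # Central differences: perturb one coordinate at a time.
    x = np.asarray(x, dtype=float)  # work in floats so the small step is not truncated
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

def sum_squares(x):
    # Same loop as in the issue: iterates over the elements of x directly.
    res = 0
    for a in x:
        res = res + a ** 2
    return res

x_check = np.array([1, 2, 3, 4, 5])  # assumed input
fd_grad = numerical_grad(sum_squares, x_check)
print(fd_grad)
```

The result should be close to `2 * x`, which matches what `sum_squares2` and `sum_squares3` report and not the all-zero output of `sum_squares`.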
Hey, the issue is that `for a in x` creates an implicit indexed access into `x`. Without a little help, Tangent doesn't know anything about this, and so doesn't know to propagate gradients through the indexing operation back to `x`. I've got a small PR I'll merge later today that desugars this implicit indexing into an explicit form. Once e.g. `sum_squares` is converted into `sum_squares2`, everything proceeds nicely. Will also add a few tests.
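To make the desugaring idea concrete, here is a rough sketch of what the explicit form looks like and why it lets gradients reach `x`. The reverse pass is written by hand for illustration; it is not Tangent's generated code, and the function names are invented here.

```python
import numpy as np

def sum_squares_desugared(x):
    # `for a in x` rewritten with an explicit index, so the read from x is visible.
    res = 0.0
    for i in range(len(x)):
        a = x[i]
        res = res + a ** 2
    return res

def sum_squares_desugared_grad(x):
    # Hand-written reverse pass (illustration only).
    x = np.asarray(x, dtype=float)
    dx = np.zeros_like(x)  # gradient buffer for x
    dres = 1.0             # seed: d(res)/d(res)
    for i in reversed(range(len(x))):
        da = 2.0 * x[i] * dres  # adjoint of `res = res + a ** 2`
        dx[i] += da             # adjoint of `a = x[i]`: route da back into slot i
    return dx

x_demo = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(sum_squares_desugared(x_demo))
print(sum_squares_desugared_grad(x_demo))
```

The line `dx[i] += da` is the step that has nothing to attach to when the loop is written as `for a in x`: with no visible `x[i]`, there is no indexing operation whose adjoint would write back into `dx`.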