google / tangent

Source-to-Source Debuggable Derivatives in Pure Python

Unexpected result in LogSumExp gradient using Tangent package in Python

paupereira opened this issue · comments

Problem:

  • First implementation:

I'm trying to get Tangent to compute the gradient of a function that contains the following implementation of logsumexp:

import numpy as np
import tangent

def logsumexp(a):
    # a = a.reshape(-1)
    result = 0.0
    largest_in_a = a[0]
    a_shape = len(a)

    # numba is slow when using max or np.max, so re-implementing:
    for i in range(1, a_shape):
        if a[i] > largest_in_a:
            largest_in_a = a[i]

    for i in range(a_shape):
        result += np.exp(a[i] - largest_in_a)

    return np.log(result) + largest_in_a

I call tangent as follows:

x = np.array([1,2,3,4])
grad_logsumexp = tangent.grad(logsumexp)

And get the result

grad_logsumexp(x)
Out[100]: array([0, 0, 0, 0])

While the correct answer is

array([0.0320586 , 0.08714432, 0.23688282, 0.64391426])
  • Second implementation:

On the other hand, doing this works:

def logsumexp_naive(a):
    return np.log(np.sum(np.exp(a)))

grad_logsumexp_naive = tangent.grad(logsumexp_naive)
grad_logsumexp_naive(x)
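For reference, the expected values quoted above follow from the analytical gradient of logsumexp, which is the softmax of the input. A minimal sketch (not part of the original report) that reproduces them without Tangent:

```python
import numpy as np

def softmax(a):
    # The gradient of logsumexp(a) with respect to a is softmax(a).
    a = np.asarray(a, dtype=float)
    shifted = a - np.max(a)  # subtract the max for numerical stability
    e = np.exp(shifted)
    return e / np.sum(e)

print(softmax(np.array([1, 2, 3, 4])))
# array([0.0320586 , 0.08714432, 0.23688282, 0.64391426])
```

This matches the "correct answer" array above, confirming that the all-zero output from the first implementation is wrong.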

Question:

What's going on with the first implementation?

Here's a simpler example of the (perhaps) same behaviour:

import numpy as np
import tangent

def sum_squares(x):
    res = 0
    for a in x:
        res = res + a ** 2
    return res

x = np.array([1, 2, 3, 4, 5])
sum_squares(x)

sum_squares_grad = tangent.grad(sum_squares)
sum_squares_grad(x)  # out: array([0, 0, 0, 0, 0])

#------------------------------------------------------------

def sum_squares2(x):
    res = 0
    for a in range(len(x)):
        res = res + x[a]**2
    return res

sum_squares2(x)

sum_squares2_grad = tangent.grad(sum_squares2)
sum_squares2_grad(x)  # out: array([ 2,  4,  6,  8, 10])

#------------------------------------------------------------

def sum_squares3(x):
    return np.sum(x ** 2)

sum_squares3(x)

sum_squares3_grad = tangent.grad(sum_squares3)
sum_squares3_grad(x)  # out: array([ 2,  4,  6,  8, 10])
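An independent way to see which of the three outputs is right is a finite-difference check, which needs no Tangent at all. A sketch (helper name mine), assuming the analytical gradient of a sum of squares is 2*x:

```python
import numpy as np

def sum_squares(x):
    return np.sum(x ** 2)

def finite_diff_grad(f, x, eps=1e-6):
    # Central-difference estimate of the gradient of f at x.
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(finite_diff_grad(sum_squares, x))
# approximately [ 2.  4.  6.  8. 10.]
```

This agrees with sum_squares2_grad and sum_squares3_grad, and shows the all-zero result from sum_squares_grad is the outlier.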

Hey, the issue is that for a in x creates an implicit indexed access into x. Without a little help, Tangent doesn't know anything about this, and so doesn't know to propagate gradients through the indexing operation back to x.

I've got a small PR I'll merge later today that desugars this implicit indexing into an explicit form. Once e.g. sum_squares is converted into sum_squares2, everything proceeds nicely. Will also add a few tests.
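In the meantime, the desugaring can be done by hand. A sketch of the workaround (function name mine; the PR's actual transformation may differ), mirroring sum_squares2:

```python
import numpy as np

def sum_squares_desugared(x):
    res = 0
    for i in range(len(x)):
        a = x[i]            # explicit index read that Tangent can track
        res = res + a ** 2
    return res

print(sum_squares_desugared(np.array([1, 2, 3, 4])))  # 30
```

Applying tangent.grad to this explicitly indexed form should behave like sum_squares2 above, since the read from x is now visible to the source transformation.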