Auto-diff of the root of a polynomial, with respect to some parameter defining the polynomial
ss1870 opened this issue · comments
Hi all,
I'm trying to use autograd to find the derivative of the root of a polynomial, with respect to a parameter defining that polynomial, when the root is found by numerical means. I've tried to show a simple implementation in the code below, where I use a 3rd order polynomial and a bisection algorithm to find the root.
The problem I'm having is that autograd returns a derivative of zero and warns me that the "Output seems independent of the input". This is in contrast to the numerical derivative producing a non-zero value that appears to be consistent with the plot of the function that I'm differentiating. I'd be incredibly grateful if someone could point out where I'm going wrong with this. Thanks in advance.
import autograd.numpy as np
from functools import partial
from scipy import optimize
from autograd import grad

# Define 3rd-order polynomial function
def polynomial(x_in, a, b, c, d):
    return a * x_in ** 3.0 + b * x_in ** 2.0 + c * x_in + d
# Define bisection root-finding algorithm
def find_root(func, lb, ub):
    i = 0  # iteration counter
    tol = 1e-8
    mid = 0.5 * (lb + ub)  # ensure mid is defined even if the bracket already meets tol
    while np.abs(ub - lb) > tol:
        fa = func(lb)
        mid = 0.5 * (lb + ub)
        fc = func(mid)
        if fa * fc < 0:  # sign change in [lb, mid]: keep the left half
            ub = mid
        else:            # same sign: the root lies in [mid, ub]
            lb = mid
        i = i + 1
    return mid, i
# Check my root finder is working (verify with standard root finder)
poly_closure = partial(polynomial, a=2, b=2, c=3, d=-4)
root, numIter = find_root(poly_closure, -2, 2)
print("Myfuncroot = " + str(root) + ", and numIter = " + str(numIter))
root_verify, res = optimize.brentq(poly_closure, -2, 2, full_output=True)
print("BrentRoot = " + str(root_verify) + ", and numIter = " + str(res.iterations))
# Define function to differentiate: the root as a function of the parameter b
def func_to_diff(param):
    poly_closure_in = partial(polynomial, a=2, b=param, c=3, d=-4)
    root1, nIter = find_root(poly_closure_in, -2, 2)
    return root1
# Numerical gradient (central finite difference) of wrapper
h = 1e-5
print("FD: dfd = " + str((func_to_diff(2.0 + h) - func_to_diff(2.0 - h)) / 2 / h))
# Try auto-diff of wrapper
grad_func = grad(func_to_diff)
print("AD: dfd = " + str(grad_func(2.0)))
We have a small piece of special machinery for differentiating through fixed-point iterations, which may fix your issue (although I'm not 100% certain and haven't run your code yet). If you can express the iterative process as a call to autograd.misc.fixed_points.fixed_point, you should get the correct derivative. The fixed_point function isn't documented AFAIK, but there's an example usage in examples/fixed_points.py.