Hessians for a list of `wrt`
adtzlr opened this issue
Idea
Use `hessian()` for the diagonal blocks because it supports the `sym` argument. The mixed partial derivatives are evaluated by hand, and only the upper-triangular blocks of the mixed partials are evaluated.
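Because only the upper triangle is evaluated, the result is a flat list ordered like `zip(*np.triu_indices(len(wrt)))`, i.e. row by row. The minimal NumPy sketch below shows how such a list maps back to a full symmetric matrix. It assumes every argument is a scalar, so each block is a single number, and the helper name `assemble_symmetric` is hypothetical and not part of tensortrax.

```python
import numpy as np


def assemble_symmetric(blocks, n):
    """Rebuild a full symmetric n x n matrix from its upper-triangle entries.

    `blocks` must be ordered like zip(*np.triu_indices(n)), i.e. row by row.
    Hypothetical helper for illustration; assumes scalar blocks only.
    """
    full = np.zeros((n, n))
    for (a, b), value in zip(zip(*np.triu_indices(n)), blocks):
        full[a, b] = value
        full[b, a] = value  # mixed partials commute for smooth functions
    return full


# upper-triangle order for n = 3: (0,0), (0,1), (0,2), (1,1), (1,2), (2,2)
print(assemble_symmetric([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3))
# [[1. 2. 3.]
#  [2. 4. 5.]
#  [3. 5. 6.]]
```

For tensor-valued arguments the lower blocks would additionally need their tensor axes swapped rather than just being copied.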
```python
import numpy as np
import tensortrax as tr
import tensortrax.math as tm


def fun(F, p, J):
    C = tm.dot(tm.transpose(F), F)
    detF = tm.linalg.det(F)
    return detF ** (-2 / 3) * (tm.trace(C) - 3) + (J - 1) ** 2 + p * (J - detF)


def hessians(fun, wrt, ntrax=0, sym=False, parallel=False):
    "Return a function which evaluates the upper-triangle Hessian blocks."

    def inner(*args, **kwargs):
        out = []
        # loop over the upper triangle (incl. diagonal) of the block-Hessian
        for a, b in zip(*np.triu_indices(len(wrt))):
            # map block indices to positional-argument indices
            i, j = wrt[a], wrt[b]
            if a == b:
                # diagonal block: use tr.hessian(), which supports `sym`
                # (only enabled for matrix-valued arguments)
                symlocal = sym and np.ndim(args[i]) - ntrax == 2
                out.append(
                    tr.hessian(
                        fun, wrt=i, ntrax=ntrax, sym=symlocal, parallel=parallel
                    )(*args, **kwargs)
                )
            else:
                # mixed block: seed δx on argument i and Δx on argument j
                tensorargs = list(args)
                tensorargs[i] = tr.Tensor(args[i], ntrax=ntrax)
                tensorargs[j] = tr.Tensor(args[j], ntrax=ntrax)
                tensorargs[i].init(hessian=True, δx=True, Δx=False)
                tensorargs[j].init(hessian=True, δx=False, Δx=True)
                out.append(tr.Δδ(fun(*tensorargs, **kwargs)))
        return out

    return inner


F = np.eye(3)
p = np.array([5])
J = np.array([3])

# blocks in order: (p, p), (p, J), (p, F), (J, J), (J, F), (F, F)
h = hessians(fun, wrt=[1, 2, 0], ntrax=0, sym=True, parallel=True)(F, p, J)
```
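To sanity-check the scalar blocks of `h`, a central finite-difference sketch in plain NumPy can help. It is independent of tensortrax, reimplements `fun` as a hypothetical `fun_np`, holds `F` fixed and only covers `p` and `J`. Analytically, the second derivatives of `fun` are 0 for `(p, p)`, 1 for `(p, J)` and 2 for `(J, J)`; with `wrt=[1, 2, 0]` these correspond to `h[0]`, `h[1]` and `h[3]`.

```python
import numpy as np


def fun_np(F, p, J):
    # plain-NumPy copy of `fun` (hypothetical, for checking only)
    C = F.T @ F
    detF = np.linalg.det(F)
    return detF ** (-2 / 3) * (np.trace(C) - 3) + (J - 1) ** 2 + p * (J - detF)


def mixed_partial_fd(f, x, i, j, h=1e-3):
    """Central finite-difference estimate of d^2 f / (dx_i dx_j) at x."""
    ei = np.zeros_like(x)
    ej = np.zeros_like(x)
    ei[i] = h
    ej[j] = h
    return (
        f(x + ei + ej) - f(x + ei - ej) - f(x - ei + ej) + f(x - ei - ej)
    ) / (4 * h**2)


F = np.eye(3)
x0 = np.array([5.0, 3.0])  # x = [p, J]


def g(x):
    # `fun` as a function of [p, J] only, with F held fixed
    return fun_np(F, x[0], x[1])


# upper triangle only, same ordering as in `hessians`
for i, j in zip(*np.triu_indices(2)):
    print(i, j, mixed_partial_fd(g, x0, i, j))  # approx. 0, 1 and 2
```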