adtzlr / tensortrax

Differentiable Tensors based on NumPy Arrays

Home Page: https://tensortrax.readthedocs.io/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Hessians for a list of `wrt`

adtzlr opened this issue · comments

Idea

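Use `hessian()` for the diagonal blocks because it supports the `sym` argument. The mixed partial derivatives are evaluated by hand. Only the upper-triangular blocks of the mixed partials are evaluated; the lower blocks follow by symmetry of the Hessian, since ∂²f/∂y∂x is the transpose of ∂²f/∂x∂y.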

import tensortrax as tr
import tensortrax.math as tm
import numpy as np
from copy import copy

def fun(F, p, J):
    """Example scalar function of three arguments ``F``, ``p`` and ``J``.

    Evaluates::

        det(F) ** (-2/3) * (trace(C) - 3) + (J - 1) ** 2 + p * (J - det(F))

    with ``C = F^T F``. All tensor operations go through ``tensortrax.math``
    so the function can be differentiated by tensortrax.
    """
    c_mat = tm.dot(tm.transpose(F), F)
    det_f = tm.linalg.det(F)
    first = det_f ** (-2 / 3) * (tm.trace(c_mat) - 3)
    second = (J - 1) ** 2
    third = p * (J - det_f)
    return first + second + third

def hessians(fun, wrt, ntrax=0, sym=False, parallel=False):
    """Build a function that evaluates the upper-triangular Hessian blocks of ``fun``.

    Parameters
    ----------
    fun : callable
        Scalar-valued function to differentiate. It is called with the
        positional/keyword arguments passed to the returned function.
    wrt : list of int
        Positional indices of the arguments of ``fun`` to differentiate with
        respect to. Block ``(a, b)`` of the result refers to arguments
        ``wrt[a]`` and ``wrt[b]``.
    ntrax : int, optional
        Forwarded to ``tr.hessian`` and ``tr.Tensor`` (tensortrax's count of
        trailing element-wise axes).
    sym : bool, optional
        Request tensortrax's symmetric treatment for the diagonal blocks.
        This applies only to matrix-valued arguments (see below).
    parallel : bool, optional
        Forwarded to ``tr.hessian``. It is used for the diagonal blocks only.

    Returns
    -------
    callable
        ``inner(*args, **kwargs)`` returns a list of Hessian blocks for the
        index pairs of ``np.triu_indices(len(wrt))``, in row-major order:
        ``(0, 0), (0, 1), ..., (0, n-1), (1, 1), ...``.
    """

    def inner(*args, **kwargs):
        out = []
        for a, b in zip(*np.triu_indices(len(wrt))):
            # Map positions in ``wrt`` to positional indices of ``fun``'s args.
            i, j = wrt[a], wrt[b]
            if a == b:
                # tensortrax's ``sym`` targets symmetric matrices. Apply it only
                # to arguments with at least two non-trailing axes. Scalars and
                # vectors (e.g. ``np.array([5])``) are never symmetry-reduced.
                # NOTE(review): confirm this matches the intended sym semantics.
                symlocal = bool(sym) and np.ndim(args[i]) - ntrax >= 2
                out.append(
                    tr.hessian(
                        fun, wrt=i, ntrax=ntrax, sym=symlocal, parallel=parallel
                    )(*args, **kwargs)
                )
            else:
                # Mixed partial: seed argument i with the δ-variation and
                # argument j with the Δ-variation, then extract the
                # second-order (Δδ) part of the result.
                tensorargs = list(args)

                tensorargs[i] = tr.Tensor(args[i], ntrax=ntrax)
                tensorargs[j] = tr.Tensor(args[j], ntrax=ntrax)

                tensorargs[i].init(hessian=True, δx=True, Δx=False)
                tensorargs[j].init(hessian=True, δx=False, Δx=True)

                out.append(tr.Δδ(fun(*tensorargs, **kwargs)))

        return out

    return inner

# Example: all upper-triangular Hessian blocks of ``fun`` with respect to
# (p, J, F), evaluated at F = I, p = [5], J = [3].
F, p, J = np.eye(3), np.array([5]), np.array([3])

h = hessians(
    fun,
    wrt=[1, 2, 0],
    ntrax=0,
    sym=True,
    parallel=True,
)(F, p, J)