adtzlr / tensortrax

Differentiable Tensors based on NumPy Arrays

Home Page: https://tensortrax.readthedocs.io/

Error in `gradient(fun)` and `hessian(fun)` if `fun` contains `x.T()` on a non-wrt argument `x`

adtzlr opened this issue

The `gradient()` and `hessian()` functions both raise an error for

d2Wdp2 = hessian(W, wrt=1, ntrax=2)(F, p, J)
d2WdJ2 = hessian(W, wrt=2, ntrax=2)(F, p, J)

but not for

d2WdF2 = hessian(W, wrt=0, ntrax=2)(F, p, J)

with

def W(F, p, J):
    C = F.T() @ F
    I1 = tr(C)
    detF = det(F)
    return (detF ** (-1 / 3) * I1 - 3) / 2 + 25 * (J - 1) ** 2 + p * (detF - J)

This is because, for `wrt > 0`, the variable `F` is a plain NumPy array instead of a tensor. On an array, `.T` is an attribute rather than a method, so `F.T()` fails.
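A minimal pure-NumPy sketch of why the call fails. The shape `(3, 3, 4, 5)` (a 3x3 matrix followed by `ntrax=2` trailing axes) is an assumption chosen for illustration, not taken from tensortrax internals. On an `ndarray`, `.T` reverses *all* axes and returns a new array, so `F.T()` tries to call that array and raises a `TypeError`.

```python
import numpy as np

# Assumed layout (illustration only): first two axes = 3x3 matrix,
# last two axes = ntrax=2 trailing (elementwise) axes.
F = np.random.default_rng(0).random((3, 3, 4, 5))

# ndarray.T is an attribute, not a method, and it reverses every axis.
print(F.T.shape)  # (5, 4, 3, 3)

# F.T() therefore tries to call the transposed array itself.
try:
    F.T()
    raised = False
except TypeError as exc:
    raised = True
    print(exc)  # 'numpy.ndarray' object is not callable
```

Note that dropping the parentheses would not help either: `F.T` moves the trailing axes to the front instead of transposing only the 3x3 matrix.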

The error can be avoided by using `transpose` from `tensortrax.math` instead:

from tensortrax.math import transpose

def W(F, p, J): 
    C = transpose(F) @ F
    # ...

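However, NumPy's matrix product `@` then contracts the wrong axes, because it treats the last two axes as the matrix dimensions, and here these are the trailing (`ntrax`) axes rather than the leading matrix axes of `F`. This is even more dangerous, because no error is raised!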
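A pure-NumPy sketch of the silent failure, under the same assumed layout as above (leading 3x3 matrix, `ntrax=2` trailing axes). The trailing axes are chosen square, `(4, 4)`, so that `@` runs without complaint; the intended product `C = F^T F` at every point is written with `np.einsum` as a stand-in, not as tensortrax's own implementation of `dot`.

```python
import numpy as np

# Assumed layout (illustration only): leading 3x3 matrix, trailing (4, 4) axes.
rng = np.random.default_rng(1)
F = rng.random((3, 3, 4, 4))

# Transpose only the two leading (matrix) axes.
FT = np.einsum("ij...->ji...", F)

# `@` treats the LAST two axes as matrices and broadcasts over the others,
# so this contracts the trailing axes: wrong, yet no error is raised.
C_wrong = FT @ F

# Intended result: C_ij = sum_k F_ki F_kj at every trailing index.
C_right = np.einsum("ki...,kj...->ij...", F, F)

print(C_wrong.shape, C_right.shape)  # (3, 3, 4, 4) (3, 3, 4, 4)
print(np.allclose(C_wrong, C_right))  # False
```

If the trailing axes were not compatible for a matrix product, `@` would at least raise a `ValueError`; with compatible shapes the wrong result passes silently.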

The only way to get the desired result is to combine `transpose` with `dot`:

from tensortrax.math import transpose, dot

def W(F, p, J): 
    C = dot(transpose(F), F)
    # ...