Error in `gradient(fun)` and `hessian(fun)` if `fun` contains `x.T()` on a non-wrt argument `x`
adtzlr opened this issue
Andreas Dutzler commented
The `gradient` and `hessian` functions both raise an error for
```python
d2Wdp2 = hessian(W, wrt=1, ntrax=2)(F, p, J)
d2WdJ2 = hessian(W, wrt=2, ntrax=2)(F, p, J)
```

but not for

```python
d2WdF2 = hessian(W, wrt=0, ntrax=2)(F, p, J)
```
with
```python
def W(F, p, J):
    C = F.T() @ F
    I1 = tr(C)
    detF = det(F)
    return (detF ** (-1 / 3) * I1 - 3) / 2 + 25 * (J - 1) ** 2 + p * (detF - J)
```
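This is because for `wrt > 0` the variable `F` is passed in as a plain (NumPy) array instead of a tensor. On an array, `.T` is an attribute (the transposed array), not a method, so `F.T()` fails.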
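A minimal NumPy-only sketch of why the call fails once `F` arrives as a plain array. The `(3, 3, 1, 1)` array below is an assumption standing in for the non-wrt input with two trailing axes (`ntrax=2`); tensortrax itself is not imported.

```python
import numpy as np

# Assumption: a plain array standing in for F, tensor axes (3, 3) first,
# followed by two trailing axes (ntrax=2) of size 1.
F = np.eye(3).reshape(3, 3, 1, 1)

def call_T_method(x):
    """Evaluate `x.T()` as written in W and report what happens."""
    try:
        x.T()
    except TypeError as err:
        return f"TypeError: {err}"
    return "ok"

message = call_T_method(F)
print(message)     # TypeError: 'numpy.ndarray' object is not callable

# On an ndarray, `.T` is an attribute that reverses ALL axes,
# including the trailing ones.
print(F.T.shape)   # (1, 1, 3, 3)
```

Note that even `F.T` without the call would not help here, because it also moves the trailing axes to the front.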
Andreas Dutzler commented
The error may be prevented by
```python
from tensortrax.math import transpose

def W(F, p, J):
    C = transpose(F) @ F
    # ...
```
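but then NumPy's matrix product `@` contracts the wrong axes: it treats the last two axes as the matrix axes, whereas the tensor axes of `F` come first and the trailing axes are the `ntrax` batch axes. This is even more dangerous because no error is raised!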
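A NumPy-only sketch of this silent failure. It assumes, hypothetically, that `transpose` on a plain array swaps only the two leading tensor axes (emulated here with `np.einsum`), and it uses two trailing axes of equal size so that `@` finds compatible shapes and raises nothing.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical layout: tensor axes (3, 3) first, then two trailing
# axes of equal size n, so that `@` does not complain about shapes.
n = 4
F = np.eye(3)[:, :, None, None] + 0.1 * rng.random((3, 3, n, n))

# Swap only the first two (tensor) axes; trailing axes stay in place.
Ft = np.einsum("ij...->ji...", F)

# `@` treats the LAST two axes as matrices and the leading ones as a
# batch, so this multiplies the trailing (n, n) blocks instead of
# contracting the 3x3 tensor axes.
C_wrong = Ft @ F

# Intended right Cauchy-Green tensor C_ij = F_ki F_kj at every point.
C_right = np.einsum("ki...,kj...->ij...", F, F)

print(C_wrong.shape, C_right.shape)   # (3, 3, 4, 4) (3, 3, 4, 4)
print(np.allclose(C_wrong, C_right))  # False
```

Both results have the same shape, which is exactly why the mistake goes unnoticed.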
The only way to get the desired result is
```python
from tensortrax.math import transpose, dot

def W(F, p, J):
    C = dot(transpose(F), F)
    # ...
```
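For comparison, a NumPy-only sketch of the contraction that `dot(transpose(F), F)` is meant to perform, C_ij = F_ki F_kj at every evaluation point, with the tensor axes first. `right_cauchy_green` is a hypothetical helper, not part of tensortrax; its result is checked against an explicit loop of 3x3 products.

```python
import numpy as np

def right_cauchy_green(F):
    """C = F^T F for F of shape (3, 3, ...), contracting the leading axes."""
    # C_ij = sum_k F_ki F_kj; trailing axes are carried along unchanged.
    return np.einsum("ki...,kj...->ij...", F, F)

rng = np.random.default_rng(1)
F = np.eye(3)[:, :, None, None] + 0.1 * rng.random((3, 3, 2, 5))

C = right_cauchy_green(F)

# Reference: ordinary 3x3 matrix products, one evaluation point at a time.
C_ref = np.empty_like(C)
for a in range(F.shape[2]):
    for b in range(F.shape[3]):
        Fab = F[:, :, a, b]
        C_ref[:, :, a, b] = Fab.T @ Fab

# Alternative: move the tensor axes to the end so that `@` contracts them.
Fm = np.moveaxis(F, (0, 1), (-2, -1))                    # (2, 5, 3, 3)
C_alt = np.moveaxis(np.swapaxes(Fm, -1, -2) @ Fm, (-2, -1), (0, 1))

print(C.shape)                     # (3, 3, 2, 5)
print(np.allclose(C, C_ref))       # True
print(np.allclose(C_alt, C_ref))   # True
```

`np.einsum` keeps the `(3, 3, ...)` layout without reshuffling, while the `np.moveaxis` variant shows that `@` only gives the right answer once the tensor axes are the last two.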
Andreas Dutzler commented