adtzlr / tensortrax

Differentiable Tensors based on NumPy Arrays

Home Page: https://tensortrax.readthedocs.io/

Getting and Setting of Tensor-items is broken

adtzlr opened this issue.

Getting and setting of tensor items needs some rework in v0.6.0 due to #58.

```python
import numpy as np
import tensortrax as tr
import tensortrax.math as tm

x = np.tile(np.arange(9).reshape(3, 3, 1, 1), (1, 1, 6, 7))
t = tr.Tensor(x, ntrax=2)
t.init(gradient=True)

print("Shape of Tensor data")
print("====================")
print(" f(t).shape =", tr.f(t).shape)
print(" δ(t).shape =", tr.δ(t).shape)
print(" Δ(t).shape =", tr.Δ(t).shape)
print("Δδ(t).shape =", tr.Δδ(t).shape)

mask = tr.f(t) > 5

# getting the selected items fails because the trailing axes
# of the real and dual data have to be broadcast first
try:
    t[mask]
except Exception as error:
    print("Error: get-item failed:", error)

# the solution is to broadcast first; np.broadcast_shapes (NumPy >= 1.20)
# also checks that the four shapes are compatible
shape = np.broadcast_shapes(t.x.shape, t.δx.shape, t.Δx.shape, t.Δδx.shape)
t.x = np.broadcast_to(t.x, shape)
t.δx = np.broadcast_to(t.δx, shape)
t.Δx = np.broadcast_to(t.Δx, shape)
t.Δδx = np.broadcast_to(t.Δδx, shape)

# the mask has to be broadcast to the same shape as well
Mask = np.broadcast_to(mask, shape)

t[Mask]
```
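The get-item failure can be reproduced with plain NumPy, independent of tensortrax. The sketch below uses hypothetical stand-in arrays, not tensortrax internals: a real array of shape (3, 3, 6, 7) and a dual array that still has size-1 trailing axes, shape (3, 3, 1, 1). The actual shapes inside a `Tensor` may differ; the point is only that a boolean mask must match the indexed array's shape exactly, and broadcasting first makes it fit.

```python
import numpy as np

# Hypothetical stand-ins for real and dual data (shapes chosen for illustration only).
x = np.tile(np.arange(9.0).reshape(3, 3, 1, 1), (1, 1, 6, 7))  # real part, (3, 3, 6, 7)
dx = np.ones((3, 3, 1, 1))  # dual part, trailing axes not yet broadcast

mask = x > 5  # boolean mask with the full shape (3, 3, 6, 7)

# Boolean indexing requires the mask shape to match the indexed axes,
# so indexing the un-broadcast dual array raises an IndexError.
try:
    dx[mask]
    failed = False
except IndexError:
    failed = True

# Broadcasting the dual array to the common shape first makes the mask fit.
shape = np.broadcast_shapes(x.shape, dx.shape)
dx_full = np.broadcast_to(dx, shape)
selected_x = x[mask]
selected_dx = dx_full[mask]
print(failed, selected_x.shape, selected_dx.shape)
```

Both selections have the same length, so the real and dual parts of the selected items stay aligned.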

Getting (slicing) and setting (item assignment) are supported in principle, but are likely to fail because the real and dual data have different shapes.
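For item assignment, broadcasting alone is not enough: `np.broadcast_to` returns a read-only view, so writing into it raises a `ValueError`. A minimal sketch, again with hypothetical plain NumPy arrays standing in for the real and dual data, materializes a writable copy of the broadcast array before assigning:

```python
import numpy as np

# Hypothetical stand-ins for real and dual data (shapes chosen for illustration only).
x = np.tile(np.arange(9.0).reshape(3, 3, 1, 1), (1, 1, 6, 7))  # real part, (3, 3, 6, 7)
dx = np.ones((3, 3, 1, 1))  # dual part with size-1 trailing axes
mask = x > 5

shape = np.broadcast_shapes(x.shape, dx.shape)

# A broadcast view is read-only; assigning into it raises ValueError.
view = np.broadcast_to(dx, shape)
try:
    view[mask] = 0.0
    read_only = True if False else False
except ValueError:
    read_only = True

# Copying the broadcast view gives an independent, writable array of full shape.
dx_full = np.broadcast_to(dx, shape).copy()
dx_full[mask] = 0.0
x[mask] = 0.0
print(read_only, dx_full.shape, int(dx_full.sum()))
```

The copy costs memory proportional to the full broadcast shape, which is the trade-off for making the dual data writable item by item.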