joschu / cgt

Computation Graph Toolkit

Error on tutorial

avivt opened this issue

I get an assertion error (assert newnewnode.typ == orig.typ) from the simplify call in the tutorial code:
cgt.print_expr(cgt.simplify([dLdw])[0]);

I am running on Windows with only the Python implementation installed (no Cython or CUDA).

This is the code I am running:
import cgt
a = cgt.scalar(name='a') # float-valued scalar, with optional name provided
b = cgt.scalar(name='b')
n = cgt.scalar(name='n', dtype='int64') # integer scalar

c = (a**n + b**n)**(1.0/n)

f = cgt.function([a,b,n], c)
print f(8,15,2)

X_nk = cgt.matrix("X")
y_n = cgt.vector("y")
w_k = cgt.vector("w")
b = cgt.scalar("b")
ypred_n = X_nk.dot(w_k) + b
L = cgt.sum(cgt.square(ypred_n - y_n))
print "L = ",
cgt.print_expr(L)
print X_nk.ndim, str(X_nk.shape), X_nk.dtype
grads = dLdw, dLdb = cgt.grad(L, [w_k, b])
print "Loss and gradient objects", dLdw, dLdb
print "Pretty-printed gradient: ",
cgt.print_expr(cgt.simplify([dLdw])[0]);
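For comparison, here is a NumPy-only sketch, independent of CGT, of what the two tutorial expressions should compute: c for the inputs (8, 15, 2), and the gradient of L = sum((X w + b - y)^2) with respect to w, which is 2 * X^T (X w + b - y). The function names (c_reference, grad_w, grad_w_numeric) are illustrative and are not part of CGT's API; the finite-difference function only serves to confirm the analytic formula.

```python
import numpy as np

def c_reference(a, b, n):
    # Same formula as the tutorial's c = (a**n + b**n)**(1.0/n)
    return (a ** n + b ** n) ** (1.0 / n)

def loss(X, y, w, b):
    # L = sum((X.dot(w) + b - y)**2), matching the tutorial's loss
    r = X.dot(w) + b - y
    return np.sum(r ** 2)

def grad_w(X, y, w, b):
    # Analytic gradient of L with respect to w: 2 * X^T (X w + b - y)
    return 2.0 * X.T.dot(X.dot(w) + b - y)

def grad_w_numeric(X, y, w, b, eps=1e-6):
    # Central finite differences, perturbing one coordinate of w at a time
    g = np.zeros_like(w)
    for k in range(w.size):
        step = np.zeros_like(w)
        step[k] = eps
        g[k] = (loss(X, y, w + step, b) - loss(X, y, w - step, b)) / (2 * eps)
    return g

if __name__ == "__main__":
    # sqrt(8**2 + 15**2) = sqrt(289) = 17
    print(c_reference(8, 15, 2))
    rng = np.random.RandomState(0)
    X = rng.randn(5, 3)
    y = rng.randn(5)
    w = rng.randn(3)
    b = 0.5
    # Analytic and numeric gradients should agree
    print(np.allclose(grad_w(X, y, w, b), grad_w_numeric(X, y, w, b), atol=1e-5))
```

Because L is quadratic in w, the central difference is exact up to floating-point rounding, which is why a tight tolerance is enough here.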

And this is the error I get:

Traceback (most recent call last):
  File "C:/test_cgt.py", line 23, in <module>
    cgt.print_expr(cgt.simplify([dLdw])[0]);
  File "C:\CGT\cgt\core.py", line 2688, in simplify
    return simplify_and_analyze(xs)[0]
  File "C:\CGT\cgt\core.py", line 2533, in simplify_and_analyze
    for output in outputs: update_simplify_map(output, analysis, repl)
  File "C:\CGT\cgt\core.py", line 2600, in update_simplify_map
    maybe_pair = process_top_stack_item_and_maybe_get_replacement(stack, analysis, repl)
  File "C:\CGT\cgt\core.py", line 2567, in process_top_stack_item_and_maybe_get_replacement
    assert newnewnode.typ == orig.typ
AssertionError

I also set a breakpoint on the assertion line and found:
newnewnode.typ = Tensor(i4,0)
orig.typ = Tensor(i8,0)
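The breakpoint shows a 32-bit integer scalar type (i4) where a 64-bit one (i8) was expected; i4 and i8 are NumPy's codes for 32- and 64-bit signed integers, and the trailing 0 is presumably the number of dimensions. One plausible explanation, not confirmed in this thread, is NumPy's platform-dependent default integer: with NumPy 1.x on Windows, a bare Python int becomes a 32-bit array, whereas on 64-bit Linux and macOS it becomes 64-bit. If some simplification step builds a constant with NumPy's default integer dtype instead of the replaced node's dtype, the types would match on Linux but not on Windows. The sketch below shows only the NumPy side of this; rebuild_constant is a hypothetical helper for illustration, not part of CGT.

```python
import numpy as np

def default_int_dtype():
    # dtype NumPy picks for a bare Python int; platform-dependent in NumPy 1.x
    return np.asarray(2).dtype

def rebuild_constant(value, like_dtype=None):
    # Hypothetical helper: build a constant, optionally pinning its dtype
    # to that of the node it replaces instead of using NumPy's default.
    if like_dtype is None:
        # Default dtype: int32 on Windows with NumPy 1.x, int64 on 64-bit Linux
        return np.asarray(value)
    return np.asarray(value, dtype=like_dtype)

if __name__ == "__main__":
    # dtype of n = cgt.scalar(name='n', dtype='int64') in the tutorial
    orig_dtype = np.dtype("int64")
    print("default int dtype: %s" % default_int_dtype())
    print("unpinned matches orig: %s" % (rebuild_constant(2).dtype == orig_dtype))
    print("pinned matches orig: %s" % (rebuild_constant(2, orig_dtype).dtype == orig_dtype))
```

If this is the cause, the unpinned comparison prints False on an affected Windows setup, and the fix would belong in CGT's simplification code (replacement constants inheriting the original node's dtype) rather than in the tutorial script.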