TypeError with enumerate
stefdoerr opened this issue · comments
def test(x):
    elem = -1
    for r, xxx in enumerate(x):
        elem = r
    return 5
In [26]: xxx = tangent.grad(test)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-26-e1ef2738f728> in <module>()
----> 1 xxx = tangent.grad(test)
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/grad_util.py in grad(func, wrt, optimized, preserve_result, check_dims, verbose)
384 check_dims=check_dims,
385 input_derivative=INPUT_DERIVATIVE.DefaultOne,
--> 386 verbose=verbose)
387
388
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/grad_util.py in autodiff(func, wrt, optimized, motion, mode, preserve_result, check_dims, input_derivative, verbose)
288 # Generate the derivative
289 node, namespace = autodiff_tree(func, wrt, motion, mode, preserve_result,
--> 290 check_dims, verbose)
291
292 if mode == 'reverse' and motion == 'joint':
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/grad_util.py in autodiff_tree(func, wrt, motion, mode, preserve_result, check_dims, verbose)
142
143 node, required = autodiff_ast(func, wrt, motion, mode, preserve_result,
--> 144 check_dims, verbose)
145 final.body.extend(node.body)
146
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/grad_util.py in autodiff_ast(func, wrt, motion, mode, preserve_result, check_dims, verbose)
95 if mode == 'reverse':
96 node, required, stack = reverse_ad.reverse_ad(node.body[0], wrt,
---> 97 preserve_result, check_dims)
98 if verbose >= 2:
99 print('RAW')
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in reverse_ad(node, wrt, preserve_result, check_dims)
841
842 ad = ReverseAD(wrt, preserve_result, check_dims)
--> 843 pri, adj = ad.visit(node)
844 mod = gast.Module(body=[pri, adj])
845 mod = annotate.find_stacks(mod)
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit(self, node)
149 if anno.hasanno(node, 'active_in'):
150 self.active_variables = anno.getanno(node, 'active_in')
--> 151 pri, adj = visitor(node)
152
153 # Annotate primal and adjoint statements
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit_FunctionDef(self, node)
212
213 # Perform AD on the function body
--> 214 body, adjoint_body = self.visit_statements(node.body[:-1])
215
216 # Annotate the first statement of the primal and adjoint as such
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit_statements(self, nodes)
285 primals, adjoints = [], collections.deque()
286 for node in nodes:
--> 287 primal, adjoint = self.visit(node)
288 if not isinstance(primal, list):
289 primal = [primal]
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit(self, node)
149 if anno.hasanno(node, 'active_in'):
150 self.active_variables = anno.getanno(node, 'active_in')
--> 151 pri, adj = visitor(node)
152
153 # Annotate primal and adjoint statements
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit_For(self, node)
311 # temporarily set aside each iteration to push to the stack later
312 push_target, pop_target, op_id_target = get_push_pop()
--> 313 tmp_target = create.create_temp(node.target, self.namer)
314
315 primal_template = grads.primals[gast.For]
/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/create.py in create_temp(node, namer)
114 name = node.value.id
115 else:
--> 116 raise TypeError
117 temp_node = gast.Name(id=namer.temp(name), annotation=None, ctx=None)
118 anno.setanno(temp_node, 'temp_var', node)
TypeError:
`enumerate` is not well supported right now (though we should have raised a proper error about that, so it's a bug either way).
As a workaround, using a regular index-based loop, `for i in range(len(x)):`, should work though.