google / tangent

Source-to-Source Debuggable Derivatives in Pure Python

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

TypeError with enumerate

stefdoerr opened this issue · comments

def test(x):
    # Minimal repro for the issue: calling tangent.grad(test) raises a bare
    # TypeError while tangent differentiates the for-loop below.
    # Only '#' comments are added here so the parsed AST that tangent
    # transforms stays exactly the same as in the original report.
    elem = -1
    # The loop target is the tuple (r, xxx) unpacked from enumerate(x).
    # Per the traceback, ReverseAD.visit_For passes this target to
    # create.create_temp, which raises TypeError. Presumably this is because
    # the target is a tuple rather than a single name -- TODO confirm
    # against tangent/create.py.
    for r, xxx in enumerate(x):
        # elem is assigned but never read; it only gives the loop a body.
        elem = r
    # The constant return means the result does not depend on x; the
    # failure comes from handling the loop, not from the returned value.
    return 5
In [26]: xxx = tangent.grad(test)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-26-e1ef2738f728> in <module>()
----> 1 xxx = tangent.grad(test)

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/grad_util.py in grad(func, wrt, optimized, preserve_result, check_dims, verbose)
    384       check_dims=check_dims,
    385       input_derivative=INPUT_DERIVATIVE.DefaultOne,
--> 386       verbose=verbose)
    387 
    388 

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/grad_util.py in autodiff(func, wrt, optimized, motion, mode, preserve_result, check_dims, input_derivative, verbose)
    288   # Generate the derivative
    289   node, namespace = autodiff_tree(func, wrt, motion, mode, preserve_result,
--> 290                                   check_dims, verbose)
    291 
    292   if mode == 'reverse' and motion == 'joint':

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/grad_util.py in autodiff_tree(func, wrt, motion, mode, preserve_result, check_dims, verbose)
    142 
    143   node, required = autodiff_ast(func, wrt, motion, mode, preserve_result,
--> 144                                 check_dims, verbose)
    145   final.body.extend(node.body)
    146 

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/grad_util.py in autodiff_ast(func, wrt, motion, mode, preserve_result, check_dims, verbose)
     95   if mode == 'reverse':
     96     node, required, stack = reverse_ad.reverse_ad(node.body[0], wrt,
---> 97                                                   preserve_result, check_dims)
     98     if verbose >= 2:
     99       print('RAW')

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in reverse_ad(node, wrt, preserve_result, check_dims)
    841 
    842   ad = ReverseAD(wrt, preserve_result, check_dims)
--> 843   pri, adj = ad.visit(node)
    844   mod = gast.Module(body=[pri, adj])
    845   mod = annotate.find_stacks(mod)

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit(self, node)
    149     if anno.hasanno(node, 'active_in'):
    150       self.active_variables = anno.getanno(node, 'active_in')
--> 151     pri, adj = visitor(node)
    152 
    153     # Annotate primal and adjoint statements

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit_FunctionDef(self, node)
    212 
    213     # Perform AD on the function body
--> 214     body, adjoint_body = self.visit_statements(node.body[:-1])
    215 
    216     # Annotate the first statement of the primal and adjoint as such

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit_statements(self, nodes)
    285     primals, adjoints = [], collections.deque()
    286     for node in nodes:
--> 287       primal, adjoint = self.visit(node)
    288       if not isinstance(primal, list):
    289         primal = [primal]

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit(self, node)
    149     if anno.hasanno(node, 'active_in'):
    150       self.active_variables = anno.getanno(node, 'active_in')
--> 151     pri, adj = visitor(node)
    152 
    153     # Annotate primal and adjoint statements

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/reverse_ad.py in visit_For(self, node)
    311     # temporarily set aside each iteration to push to the stack later
    312     push_target, pop_target, op_id_target = get_push_pop()
--> 313     tmp_target = create.create_temp(node.target, self.namer)
    314 
    315     primal_template = grads.primals[gast.For]

/shared/sdoerr/Software/miniconda3/lib/python3.6/site-packages/tangent/create.py in create_temp(node, namer)
    114     name = node.value.id
    115   else:
--> 116     raise TypeError
    117   temp_node = gast.Name(id=namer.temp(name), annotation=None, ctx=None)
    118   anno.setanno(temp_node, 'temp_var', node)

TypeError: 

`enumerate` is not well supported right now (though we should have raised a clear error about that instead of a bare `TypeError`, so it's a bug either way).
Using a plain index loop, `for i in range(len(x))`, should work, though.