Data-dependent branches not working
chenyang-tao opened this issue
Hi guys, the following are some issues with data-dependent branching that I have encountered lately.
- This exact replica of the data-dependent branching code from the README does not work as expected.
```python
import minpy.numpy as np

num = 100
X = np.linspace(-1.0, 1.0, num=num)
Y = np.zeros([num,])
if X < Y:
    Z = X + Y
else:
    Z = Y ** 2
# Only Z = X + Y is executed
```
- Things get hairier when I want to take gradients through a branch.
```python
import minpy
from minpy.core import grad

def foo(X):
    # This will raise an AttributeError
    if X <= 0:
        Y = X**2
    else:
        Y = 0
    # And this will work
    if -X >= 0:
        Y = X**2
    else:
        Y = 0
    return Y

foo_grad = grad(foo)
print(foo_grad(-1.0))
```
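For reference, the value `foo_grad(-1.0)` should produce can be checked without MinPy. This is a minimal sketch, assuming the intended function is f(x) = x**2 for x <= 0 and 0 otherwise (`foo` with a single branch); `foo_plain` and `numeric_grad` are hypothetical helper names, and the check is a central finite difference in plain Python, not MinPy's autograd:

```python
def foo_plain(x):
    # Same branching logic as foo, on a plain Python float (assumed intent)
    if x <= 0:
        return x ** 2
    return 0.0


def numeric_grad(f, x, eps=1e-6):
    # Central finite difference: (f(x + eps) - f(x - eps)) / (2 * eps)
    return (f(x + eps) - f(x - eps)) / (2 * eps)


g = numeric_grad(foo_plain, -1.0)
print(g)  # approximately -2.0, i.e. the derivative 2x of x**2 at x = -1.0
```

Both probe points (-1.0 ± 1e-6) fall in the `x <= 0` branch, so the estimate is the derivative of x**2 only.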
Hi,
The first example is ambiguous. If you replace the namespace with numpy, you will get:
```
Traceback (most recent call last):
  File "t1.py", line 8, in <module>
    if X < Y:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
This happens because both `X` and `Y` are arrays, so `X < Y` returns a boolean array rather than a single truth value. So what do you mean by the condition `X < Y`: `np.all(X < Y)` or `np.any(X < Y)`? Only `X + Y` is executed because MinPy uses the first element as the condition value, which is true here (`X[0] = -1.0 < Y[0] = 0.0`). So I think the behavior is correct, except that we should raise the same error as NumPy.
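A short sketch of making the condition explicit by reducing the boolean array yourself. It uses plain NumPy (not `minpy.numpy`) and reuses the arrays from the first example:

```python
import numpy as np

num = 100
X = np.linspace(-1.0, 1.0, num=num)
Y = np.zeros((num,))

mask = X < Y  # elementwise comparison: a boolean array of shape (100,)

# `if mask:` calls bool(mask), which NumPy rejects for multi-element arrays
raised = False
try:
    bool(mask)
except ValueError:
    raised = True

# Choose the reduction explicitly instead
if np.all(mask):  # True only if every X[i] < Y[i]
    Z_all = X + Y
else:
    Z_all = Y ** 2

if np.any(mask):  # True if at least one X[i] < Y[i]
    Z_any = X + Y
else:
    Z_any = Y ** 2
```

Here `np.all(mask)` is false (the last element of `X` is 1.0), so `Z_all` takes the `Y ** 2` branch, while `np.any(mask)` is true, so `Z_any` takes the `X + Y` branch.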
The second one is the same as #139, where MXNet's operators have poor support for scalar arguments. We will fix this in MXNet.
Thanks for the report!
Thanks for the reply. I am now even more confused about the first example. The following code is given in README.md as an example demonstrating MinPy's features, highlighting that "... you freely use the if statement anyway you like."
```python
import minpy.numpy as np

x = ...  # create x array
y = ...  # create y array
if x < y:
    z = x + y
else:
    z = y ** 2
```
So if only the first element of the returned boolean array is used, this example is misleading.
Looking forward to your fix for the scalar-argument issue; it causes a lot of trouble.
I see. The example is only meaningful for scalars, not for arrays. To make it appropriate for arrays, you should have something like:
```python
import minpy.numpy as np

x = ...
y = ...
if x[0] < y[0]:
    z = x + y
else:
    z = y ** 2
```
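If the intent is instead an elementwise branch (take `X + Y` where `X < Y` and `Y ** 2` elsewhere), no `if` statement is needed. A sketch in plain NumPy using `np.where`, reusing the arrays from the first example; whether `minpy.numpy` exposes the same function is not confirmed here. The explicit loop is included only to show that both forms compute the same thing:

```python
import numpy as np

num = 100
X = np.linspace(-1.0, 1.0, num=num)
Y = np.zeros((num,))

# Elementwise "if/else": X + Y where the condition holds, Y ** 2 elsewhere
Z = np.where(X < Y, X + Y, Y ** 2)

# Equivalent explicit loop, with one scalar condition per element
Z_loop = np.empty_like(X)
for i in range(num):
    if X[i] < Y[i]:
        Z_loop[i] = X[i] + Y[i]
    else:
        Z_loop[i] = Y[i] ** 2
```

The first 50 points of `X` are negative, so `Z` equals `X` there and 0 elsewhere.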
@jermainewang Can you update the image in web-data? Thanks!