dmlc / minpy

NumPy interface with mixed backend execution

Home Page: https://minpy.readthedocs.io/en/latest/

Data dependent branches not working

chenyang-tao opened this issue

Hi guys, the following are some issues with data dependent branching that I encountered lately.

  1. This exact replica of the data dependent branching code does not work as expected.
import minpy.numpy as np

num = 100
X = np.linspace(-1.0,1.0,num=num)
Y = np.zeros([num,])

if X<Y:
    Z = X + Y
else:
    Z = Y ** 2

# Only Z = X + Y is executed
  2. Things get hairier if I want to take gradients through the branching.
import minpy
from minpy.core import grad

def foo(X):

    # This will raise an AttributeError
    if X<=0:
        Y = X**2
    else:
        Y = 0

    # And this will work
    if -X>=0:
        Y = X**2
    else:
        Y = 0
    
    return Y

foo_grad = grad(foo)

print(foo_grad(-1.0))

Hi,

The first example is ambiguous. If you replace the namespace with numpy, you will get:

Traceback (most recent call last):
  File "t1.py", line 8, in <module>
    if X < Y:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

This is because both X and Y are arrays, so X < Y returns a boolean array of True and False values. So what do you mean by the condition X < Y? Is it np.all(X < Y) or np.any(X < Y)? Only X + Y is executed because minpy uses the first element as the condition value (which is true here). So I think the behavior is correct, except that we should also raise the same error as numpy.
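To make the ambiguity concrete, here is a minimal sketch in plain NumPy (not minpy) using the same X and Y as in the report. The variable names `mask`, `raised`, `all_less`, `any_less` and `first_less` are just for illustration.

```python
import numpy as np

num = 100
X = np.linspace(-1.0, 1.0, num=num)
Y = np.zeros([num])

# Elementwise comparison: a boolean array, not a single True/False.
mask = X < Y

# Using the array directly as a condition is what NumPy rejects.
try:
    bool(mask)
    raised = False
except ValueError:
    raised = True

# The two explicit reductions answer different questions.
all_less = bool(np.all(mask))   # is every element of X below Y?
any_less = bool(np.any(mask))   # is at least one element of X below Y?

# Looking only at the first element gives yet another answer.
first_less = bool(mask[0])

print(raised, all_less, any_less, first_less)
```

Here the first half of X is negative and the second half is positive, so the three readings of the condition disagree, which is why an explicit np.all or np.any is needed.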

The second one is the same as #139, where MXNet's operators have poor support for scalar arguments. We will fix this in MXNet.
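For reference, the value that the gradient should take at -1.0 can be checked without minpy at all. The sketch below uses a central finite difference in plain Python; `numeric_grad` is a hypothetical helper written for this check, not part of minpy or MXNet.

```python
def foo(x):
    # Same branching as in the report, written for a plain Python float.
    if x <= 0:
        return x ** 2
    return 0.0


def numeric_grad(f, x, h=1e-6):
    # Central finite difference: a numerical stand-in for grad(f)(x).
    return (f(x + h) - f(x - h)) / (2.0 * h)


approx = numeric_grad(foo, -1.0)
print(approx)  # close to 2 * x = -2 on the x <= 0 branch
```

Both forms of the condition in the report (X<=0 and -X>=0) select the same branch for a scalar, so they should yield the same gradient once scalar arguments are handled properly.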

Thanks for the report!

Thanks for the reply. I am now even more confused about the first example. The following code is given in README.md as an example demonstrating minpy's features, highlighting "... you freely use the if statement anyway you like."

import minpy.numpy as np

x = ... # create x array
y = ... # create y array

if x < y:
    z = x + y
else:
    z = y ** 2

So if only the first element of the returned bool array will be used, then this example is misleading.

Looking forward to your fix for the scalar arguments issue; it causes a lot of trouble.

I see. The example is only meant for scalars; it is not meaningful for arrays. To make it appropriate for arrays, you should write:

import minpy.numpy as np

x = ...
y = ...

if x[0] < y[0]:
  z = x + y
else:
  z = y ** 2
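If what you actually want is a per-element choice between the two branches rather than one decision for the whole array, a selection such as np.where is the usual way to express it. This is a sketch in plain NumPy with made-up input values; whether minpy.numpy supports where, and gradients through it, is not confirmed in this thread.

```python
import numpy as np

# Small illustrative inputs (not from the original report).
x = np.array([-1.0, 0.5, 2.0])
y = np.array([0.0, 0.0, 3.0])

# For each index i: z[i] = x[i] + y[i] if x[i] < y[i], else y[i] ** 2.
z = np.where(x < y, x + y, y ** 2)

print(z)
```

Note that both candidate arrays (x + y and y ** 2) are computed in full before the selection, unlike a Python if statement, which evaluates only one branch.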

@jermainewang Can you update the image in web-data? Thanks!