Problem with `where` function
zarif98sjs opened this issue
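Shouldn't the `where` function be this?

```python
import torch

def where(q, a, b):
    "Use this function to replace an if-statement."
    # logical_not maps 0 -> True and nonzero -> False, so q may be bool or integer
    return (q * a) + torch.logical_not(q) * b
```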
Otherwise, if we use `~q`, isn't that technically incorrect according to the desired function outcome?
If we use `~q`, `where(arange(4) * 0, 0, 1)` returns `tensor([-1, -1, -1, -1])`, but the desired output is `tensor([1, 1, 1, 1])`.
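A minimal side-by-side sketch of the two versions follows. It assumes the original definition used `~q` where the proposed fix uses `torch.logical_not(q)`, as the issue implies; the names `where_bitwise` and `where_logical` are hypothetical and exist only for this comparison.

```python
import torch

def where_bitwise(q, a, b):
    # Variant using ~q: a logical NOT only when q is a bool tensor
    return (q * a) + (~q) * b

def where_logical(q, a, b):
    # Proposed fix: logical_not works the same for bool and integer q
    return (q * a) + torch.logical_not(q) * b

q = torch.arange(4) * 0  # int64 tensor of zeros
print(where_bitwise(q, 0, 1))  # tensor([-1, -1, -1, -1])
print(where_logical(q, 0, 1))  # tensor([1, 1, 1, 1])

# With a bool mask both variants agree, since ~ on a bool tensor is logical NOT
mask = torch.tensor([True, False, True, False])
print(where_bitwise(mask, 5, 7))  # tensor([5, 7, 5, 7])
print(where_logical(mask, 5, 7))  # tensor([5, 7, 5, 7])
```

The bug only surfaces when `q` is an integer tensor, which would explain why it could go unnoticed when callers pass boolean masks.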
I agree. `~` is bitwise NOT, so the behavior is unexpected if `q` is a tensor of integers.
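The same distinction shows up with plain Python integers, which makes the `-1` above easy to explain: bitwise NOT flips every bit, so `~x == -x - 1`, while logical NOT only checks whether `x` is zero. The helper names below are hypothetical, chosen only for illustration.

```python
def bitwise_not(x: int) -> int:
    # ~ flips every bit; for Python ints this equals -x - 1
    return ~x

def logical_not(x: int) -> bool:
    # `not` only asks whether x is zero
    return not x

values = [0, 1, 2]
print([bitwise_not(x) for x in values])  # [-1, -2, -3]
print([logical_not(x) for x in values])  # [True, False, False]
```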
Oops, will fix if I do a new version.
Ah, nice! When creating the issue, I was wondering why nobody had noticed all these years 😅 I can send a PR if you want.
No, I should just do a v2 this summer. Lots of small fixes abound.