PaddlePaddle / Paddle

PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (『飞桨』核心框架,深度学习&机器学习高性能单机、分布式训练和跨平台部署)

Home Page:http://www.paddlepaddle.org/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

`masked_fill_`對int64處理異常,塞入paddle.iinfo(paddle.int64).max會被當作min

anderson101866 opened this issue · comments

bug描述 Describe the Bug

如下單純的script,帶入int64最大值時,會變成最小值

import paddle; print(paddle.__version__) #2.6.0

t = paddle.zeros((2,2), dtype=paddle.int64)
print(paddle.iinfo(paddle.int64).max == 2**63-1) #True
t.masked_fill_(paddle.to_tensor([[0,0],[0,1]]), 2**63-1) #<--------------
print(t) 
#Tensor(shape=[2, 2], dtype=int64, place=Place(gpu:1), stop_gradient=True,
#       [[ 0                  ,  0                  ],
#        [ 0                  , -9223372036854775808]])

其他补充信息 Additional Supplementary Information

No response

@AndSonder 这个问题能否辛苦看下呢?

OP来源:#57355

@zyfncg 看起来是 full op 里面的 bug,如下代码会产生溢出

>>> paddle.full([], 2**63-1, paddle.int64)
Tensor(shape=[], dtype=int64, place=Place(gpu:0), stop_gradient=True,
       -9223372036854775808)

@anderson101866 您可以通过如下代码实现想要的功能

import paddle; print(paddle.__version__) #2.6.0

t = paddle.zeros((2,2), dtype=paddle.int64)
print(paddle.iinfo(paddle.int64).max == 2**63-1) #True
val = paddle.to_tensor(2**63-1, paddle.int64)
t.masked_fill_(paddle.to_tensor([[0,0],[0,1]]), val) #<--------------
print(t) 
#Tensor(shape=[2, 2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
#       [[0                  , 0                  ],
#        [0                  , 9223372036854775807]])

@zyfncg 看起来是 full op 里面的 bug,如下代码会产生溢出

>>> paddle.full([], 2**63-1, paddle.int64)
Tensor(shape=[], dtype=int64, place=Place(gpu:0), stop_gradient=True,
       -9223372036854775808)

👍

这样的话paddle.full是否可以替换为to_tensor呢?

@zyfncg 看起来是 full op 里面的 bug,如下代码会产生溢出

>>> paddle.full([], 2**63-1, paddle.int64)
Tensor(shape=[], dtype=int64, place=Place(gpu:0), stop_gradient=True,
       -9223372036854775808)

👍

这样的话paddle.full是否可以替换为to_tensor呢?

可以替换,但是是不是还是修复 full 的这个 bug 会好一些,要不还有可能有其他 api 有类似的没发现的越界问题

類似的問題我也有碰到 也很像overflow的問題

import paddle; print(paddle.__version__) #2.6.0
x = paddle.to_tensor([[0, 0], 
                      [-2**63, 0]], dtype=paddle.int64)
print(x) 
x == 2**63-1 # __eq__
#Tensor(shape=[2, 2], dtype=bool, place=Place(gpu:0), stop_gradient=True,
#       [[False, False],
#        [True , False]])

一樣是max會被當成min

類似的問題我也有碰到 也很像overflow的問題

import paddle; print(paddle.__version__) #2.6.0
x = paddle.to_tensor([[0, 0], 
                      [-2**63, 0]], dtype=paddle.int64)
print(x) 
x == 2**63-1 # __eq__
#Tensor(shape=[2, 2], dtype=bool, place=Place(gpu:0), stop_gradient=True,
#       [[False, False],
#        [True , False]])

一樣是max會被當成min

这个应该是数值溢出导致的,也有可能是类似的原因导致的

image

进一步分析了下,通过pybind将Python数据类型转到C++时2**63-1会被识别为float类型,而2**63-12**63的浮点表示是相同的,所以在C++层转回int64类型时解析为2**63就会出现精度溢出的问题。

这个问题看起来不是很好解决,对于临界值的处理经常会有风险,如果有绕过方式建议先避开这里的逻辑。

image 进一步分析了下,通过pybind将Python数据类型转到C++时`2**63-1`会被识别为float类型,而`2**63-1`和`2**63`的浮点表示是相同的,所以在C++层转回int64类型时解析为`2**63`就会出现精度溢出的问题。

这个问题看起来不是很好解决,对于临界值的处理经常会有风险,如果有绕过方式建议先避开这里的逻辑。

这个 float 会不会是这里导致的?

image

验证了下,确实是这里导致的

还有一个问题就是在静态图模式下这里的value也只能用float类型表示,动态图的类型即使转为int64,在动转静后还是会遇到这里精度溢出的问题