AI4Finance-Foundation / RLSolver

Solvers for NP-hard and NP-complete problems with an emphasis on high-performance GPU computing.

Home Page: https://ai4finance.org

🐛 Contracting in some orders on a circuit with a ring structure may give incorrect multiplication counts.

Yonv1943 opened this issue

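Contracting circuits with a ring (cyclic) structure, such as Sycamore, tensor grid, and tensor ring, triggers a bug that makes the multiplication count wrong.
(The tensor train and tensor tree networks we happened to test have no cycles.)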

The bug is only triggered when the network has a cyclic structure and the tensor nodes are contracted in a particular order.
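To see why cycles matter, here is a minimal union-find sketch in plain Python (not the repository's code; `already_merged_edges` is a hypothetical helper). It walks the edges in a given order and reports every edge whose two endpoints have already been merged into one tensor by earlier contractions. In a tree no such edge exists; in a ring at least one always does. The wrong count only becomes visible if the merged tensor still has bonds to the rest of the network when that edge is reached, which is why the contraction order matters.

```python
def already_merged_edges(num_nodes, edges):
    """Return the edges whose endpoints already share a tensor when reached."""
    parent = list(range(num_nodes))  # union-find forest over the original nodes

    def find(x):
        # path halving keeps the trees shallow
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    repeated = []
    for a, b in edges:
        ra, rb = find(a), find(b)
        if ra == rb:
            repeated.append((a, b))  # both ends already merged: no real contraction
        else:
            parent[ra] = rb  # a real contraction merges the two groups
    return repeated


ring = [(0, 1), (1, 2), (2, 3), (3, 0)]
tree = [(0, 1), (1, 2), (1, 3)]
print(already_merged_edges(4, ring))  # [(3, 0)]
print(already_merged_edges(4, tree))  # []
```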

The code below was run on a small Sycamore circuit, NodesSycamoreN12M14, and checking its output line by line revealed this bug.

num_nodes             51
num_edges             99
ban_edges              0
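As a rough scale check, assuming this network is connected: every contraction that joins two different tensors reduces the tensor count by one, so exactly num_nodes - 1 edges perform real merges whatever the order, and every remaining edge finds both of its endpoints already in the same tensor:

$$\text{num\_edges} - (\text{num\_nodes} - 1) = 99 - (51 - 1) = 49$$

So 49 of the 99 edges in NodesSycamoreN12M14 reach two nodes that were already merged, and each of them is a chance for the double counting shown below.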

Rough notes for now.

The debug print code:

'''calculate the multiple and avoid repeat'''
contract_dims = node_dims_arys[node_i0] + node_dims_arys[node_i1]  # bond dimensions of the merged node and which nodes they come from
contract_bool = node_bool_arys[node_i0] | node_bool_arys[node_i1]  # which original nodes make up the merged node
# assert contract_dims.shape == (num_nodes, )
# assert contract_bool.shape == (num_nodes, )

print(';;;', i, node_i0, node_i1)
print(node_dims_arys[node_i0].cpu().numpy().astype(int))  # .cpu() so this also works for GPU tensors
print(node_dims_arys[node_i1].cpu().numpy().astype(int))
print(contract_dims.cpu().numpy().astype(int))
print(contract_bool.cpu().numpy().astype(int))

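This is the printed output. It shows that an impossible contraction was performed on two nodes that had already been contracted together, producing extra multiplications: the rows of nodes 3 and 9 are identical, and contract_dims is exactly twice that row.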

;;; 52 tensor(3) tensor(9)
[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 352   0   0   0   0  96   0   0   0 128   0   0   0   0 320   0 128   0   0   0  64   0  64  64   0   0   0   0 192   0  64  64   0]
[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 352   0   0   0   0  96   0   0   0 128   0   0   0   0 320   0 128   0   0   0  64   0  64  64   0   0   0   0 192   0  64  64   0]
[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 704   0   0   0   0 192   0   0   0 256   0   0   0   0 640   0 256   0   0   0 128   0 128 128   0   0   0   0 384   0 128 128   0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 0 1 1 1 0 1 1 1 1 0 1 0 1 0 1 0 1 0 0 1 1 0 0 0 1 0 0 1]
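The printed rows can be checked by hand. The plain-Python sketch below (not the repository's code) rebuilds node_dims_arys[3] from the printout and treats the printed contract_bool as the membership row of the tensor that already contains nodes 3 and 9. That is an assumption: since both nodes sit in one tensor, their bool rows should coincide. It then evaluates the issue's multiplication-exponent formula once without and once with the `if_diff` guard.

```python
NUM_NODES = 51

# Nonzero entries of node_dims_arys[3] (identical to node_dims_arys[9]) in the printout
row = [0] * NUM_NODES
for idx, val in {18: 352, 23: 96, 27: 128, 32: 320, 34: 128, 38: 64,
                 40: 64, 41: 64, 46: 192, 48: 64, 49: 64}.items():
    row[idx] = val

# Printed contract_bool: 0 at these indices, 1 everywhere else
zero_idx = {18, 23, 27, 32, 34, 36, 38, 40, 41, 44, 45, 46, 48, 49}
group = [i not in zero_idx for i in range(NUM_NODES)]


def step(dims0, dims1, group, if_diff):
    """Mirror of the issue's per-step arithmetic on plain lists."""
    contract = [a + b * if_diff for a, b in zip(dims0, dims1)]
    internal = sum(c for c, g in zip(contract, group) if g)  # bonds inside the merged tensor
    cost = (sum(contract) - 0.5 * internal) * if_diff
    return cost, contract


naive_cost, naive_dims = step(row, row, group, True)          # old code: always add both rows
fixed_cost, fixed_dims = step(row, row, group, not group[9])  # if_diff = not node_bool_arys[3, 9]
print(naive_cost, naive_dims[18])  # 3072.0 704
print(fixed_cost, fixed_dims[18])  # 0.0 352
```

Without the guard the step is charged 2**3072 multiplications, far beyond the float64 range of roughly 2**1024, and the doubled row (704, 192, ...) is what gets written back for every node of the tensor; with the guard the step costs nothing and the row is unchanged.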

The following code fixes this bug in Vanilla (single) mode.

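The if_diff flag is what fixes it: when node_i0 and node_i1 already belong to the same contracted tensor, if_diff is False, so node_i1's row is not added a second time and the step contributes no multiplications.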
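The same guard can be checked on a network small enough to count by hand: a ring of nodes 0, 1, 2 with a tail edge to node 3, every bond of dimension 2 (log2 dimension 1). The sketch below is a plain-Python re-implementation of one vanilla step under those assumptions; `contract_edge` and `run` are hypothetical helpers, not repository functions. Edge (0, 2) is reached after nodes 0, 1 and 2 have already been merged, so with the guard it costs nothing, while without it the step is charged 2**2 spurious multiplications and the bond to node 3 is doubled, which also corrupts the next step.

```python
def contract_edge(dims, groups, i0, i1, guard=True):
    """One contraction step; returns the log2 multiplication count.

    dims[k][m]:   log2 bond dimension between the tensor holding original
                  node k and original node m (0 when there is no bond).
    groups[k][m]: True if original node m is already merged into k's tensor.
    """
    n = len(dims)
    if_diff = (not groups[i0][i1]) if guard else True
    contract = [dims[i0][m] + dims[i1][m] * if_diff for m in range(n)]
    merged = [groups[i0][m] or groups[i1][m] for m in range(n)]
    internal = sum(c for c, g in zip(contract, merged) if g)  # each internal bond counted twice
    cost = (sum(contract) - 0.5 * internal) * if_diff
    if if_diff:
        # contracted bonds get log2 dim 0 (size 2**0 = 1)
        contract = [0 if g else c for c, g in zip(contract, merged)]
    for k in range(n):
        if merged[k]:  # refresh every node of the merged tensor with the same rows
            dims[k] = list(contract)
            groups[k] = list(merged)
    return cost


def run(edges, n, guard):
    dims = [[0] * n for _ in range(n)]
    for a, b in edges:  # every bond has dimension 2, i.e. log2 dim 1
        dims[a][b] += 1
        dims[b][a] += 1
    groups = [[m == k for m in range(n)] for k in range(n)]
    return [contract_edge(dims, groups, a, b, guard) for a, b in edges]


edges = [(0, 1), (1, 2), (0, 2), (2, 3)]  # ring 0-1-2 plus a tail edge to node 3
print(run(edges, 4, guard=True))   # [3.0, 3.0, 0.0, 1.0]
print(run(edges, 4, guard=False))  # [3.0, 3.0, 2.0, 1.5]
```

The guarded exponents agree with direct counting: 2·2·2 = 8 multiplications for the first contraction, 2·2·2 = 8 for absorbing node 2, and 2 for the final dot product with node 3.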

For the data type, it is safer to use float32 rather than int/long for node_dims_tens.

node_dims_tens = th.stack([self.node_dims_ten.clone() for _ in range(num_envs)]).type(th.float32)  # float32 rather than int/long
node_bool_tens = th.stack([self.node_bool_ten.clone() for _ in range(num_envs)]).type(th.bool)

"""Vanilla (single)"""
for j in range(num_envs):
    edge_i = edge_is[j]
    node_dims_arys = node_dims_tens[j]
    node_bool_arys = node_bool_tens[j]

    '''find two nodes of an edge_i'''
    node_i0, node_i1 = th.where(edges_ary == edge_i)[0]  # the two nodes at either end of this edge
    # assert isinstance(node_i0.item(), int)
    # assert isinstance(node_i1.item(), int)

    '''whether node_i0 and node_i1 are different'''
    if_diff = th.logical_not(node_bool_arys[node_i0, node_i1])  # False if both already belong to the same contracted tensor

    '''calculate the multiple and avoid repeat'''
    contract_dims = node_dims_arys[node_i0] + node_dims_arys[node_i1] * if_diff  # bond dimensions of the merged node and which nodes they come from
    contract_bool = node_bool_arys[node_i0] | node_bool_arys[node_i1]  # which original nodes make up the merged node
    # assert contract_dims.shape == (num_nodes, )
    # assert contract_bool.shape == (num_nodes, )

    # A contracted edge is multiplied only once, but its exponent appears in both rows above,
    # so the sum of these duplicated exponents is multiplied by 0.5
    mult_pow_time = contract_dims.sum(dim=0) - (contract_dims * contract_bool).sum(dim=0) * 0.5
    mult_pow_timess[j, i] = mult_pow_time * if_diff  # i: contraction step index from the enclosing loop (not shown)

    '''adjust two lists: node_dims_arys, node_bool_arys'''
    # If the two tensors are the same, `contract_bool & if_diff` is all False, so the next line changes nothing
    contract_dims[contract_bool & if_diff] = 0  # set contracted bonds to 2**0 so they no longer count toward later multiplications
    node_dims_tens[j, contract_bool] = contract_dims.repeat(1, 1)  # refresh every node of the merged tensor with the same information
    node_bool_tens[j, contract_bool] = contract_bool.repeat(1, 1)  # same refresh for the membership mask
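The 0.5 factor follows from how the rows are stored. Let $G$ be the set of original nodes in the merged tensor (contract_bool) and $c_m$ the entry contract_dims[m], the summed log2 bond dimension from the two input tensors to original node $m$. Contracting two tensors costs the product of the dimensions of every distinct index involved. A bond to a node $m \notin G$ appears once in $c_m$, while each bond between the two inputs appears twice, once in each input's row, so

$$\log_2 M = \sum_{m \notin G} c_m + \frac{1}{2} \sum_{m \in G} c_m = \sum_{m} c_m - \frac{1}{2} \sum_{m \in G} c_m,$$

which is `contract_dims.sum(dim=0) - (contract_dims * contract_bool).sum(dim=0) * 0.5`. If the two nodes already share a tensor, their rows are equal and zero on $G$, so without the guard $c = 2 \cdot \text{row}$ and the formula reports $2 \sum_{m \notin G} \text{row}_m$ instead of 0; multiplying by if_diff removes that term.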

Vectorized version

"""Vectorized"""
'''find two nodes of an edge_i'''
vec_edges_ary: TEN = edges_ary[None, :, :]
vec_edges_is: TEN = edge_is[:, None, None]
res = th.where(vec_edges_ary == vec_edges_is)[1]
res = res.reshape((num_envs, 2))
node_i0s, node_i1s = res[:, 0], res[:, 1]
# assert node_i0s.shape == (num_envs, )
# assert node_i1s.shape == (num_envs, )

'''whether node_i0 and node_i1 are different'''
if_diffs = th.logical_not(node_bool_tens[vec_env_is, node_i0s, node_i1s])

'''calculate the multiple and avoid repeat'''
contract_dimss = node_dims_tens[vec_env_is, node_i0s] + node_dims_tens[
    vec_env_is, node_i1s] * if_diffs.unsqueeze(1)
contract_bools = node_bool_tens[vec_env_is, node_i0s] | node_bool_tens[vec_env_is, node_i1s]
# assert contract_dimss.shape == (num_envs, num_nodes)
# assert contract_bools.shape == (num_envs, num_nodes)

mult_pow_times = contract_dimss.sum(dim=1) - (contract_dimss * contract_bools).sum(dim=1) * 0.5
# assert mult_pow_times.shape == (num_envs, )
mult_pow_timess[:, i] = mult_pow_times * if_diffs

'''adjust two lists: node_dims_arys, node_bool_arys'''
for j in range(num_envs):  # refresh every node of each merged tensor with the same information
    contract_dims = contract_dimss[j]
    contract_bool = contract_bools[j]

    contract_dims[contract_bool & if_diffs[j]] = 0  # set contracted bonds to 2**0 so they no longer count toward later multiplications
    node_dims_tens[j, contract_bool] = contract_dims.repeat(1, 1)
    node_bool_tens[j, contract_bool] = contract_bool.repeat(1, 1)
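The vectorized lookup relies on the order of `th.where`: with a single argument it returns the indices of True elements in row-major order, so the matches are grouped by environment and sorted by node index within each environment, and `reshape((num_envs, 2))` pairs them correctly only if every edge id appears exactly twice. The plain-Python sketch below mimics that scan; `find_endpoints` is a hypothetical helper, and padding short rows of `edges_ary` with -1 is an assumption about its layout.

```python
def find_endpoints(edges_ary, edge_ids):
    """Return the (node_i0, node_i1) pair for each requested edge id.

    edges_ary[node] lists the edge ids attached to that node (short rows
    padded with -1 here, an assumed layout). Nodes are scanned in increasing
    order, matching the row-major order of the batched th.where lookup.
    """
    pairs = []
    for edge_id in edge_ids:
        nodes = [node for node, row in enumerate(edges_ary)
                 for val in row if val == edge_id]
        # the batched reshape((num_envs, 2)) fails or mispairs if this does not hold
        assert len(nodes) == 2, f"edge {edge_id} appears {len(nodes)} times"
        pairs.append((nodes[0], nodes[1]))
    return pairs


# Edges 0:(0,1), 1:(1,2), 2:(0,2), 3:(2,3)
edges_ary = [[0, 2, -1], [0, 1, -1], [1, 2, 3], [3, -1, -1]]
print(find_endpoints(edges_ary, [2, 0, 3]))  # [(0, 2), (0, 1), (2, 3)]
```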

After the fix, I also used the change-of-base formula (`max_tmp_power / th.log2(th.tensor((10,), device=device))`) to deal with the overflow that appeared once this bug was fixed.

# Even float64 occasionally overflows when computing this multiplication count,
# so first divide by 2**max_tmp_power, take log10, then add the offset back
max_tmp_power = mult_pow_timess.max(dim=1)[0] - 960  # automatically set `max - 960`; 960 < the float64 exponent limit of 1024
multiple_times = (2 ** (mult_pow_timess - max_tmp_power.unsqueeze(1))).sum(dim=1)
multiple_times = multiple_times.log10() + max_tmp_power / th.log2(th.tensor((10,), device=device))  # change of base: log10(2**p) = p / log2(10)
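The same rescaling can be written in plain Python to show why it avoids overflow: factor out a power of two before summing, then add its base-10 logarithm back with the change-of-base formula. This is a minimal sketch, not the repository's code; `log10_sum_pow2` is a hypothetical helper. It shifts by the maximum exponent itself, the usual log-sum-exp choice, whereas the code above shifts by `max - 960`.

```python
import math


def log10_sum_pow2(powers):
    """log10(sum(2 ** p for p in powers)) without overflowing float64.

    With m = max(powers):
    log10(sum 2**p) = log10(sum 2**(p - m)) + m / log2(10)
    """
    m = max(powers)
    scaled = sum(2.0 ** (p - m) for p in powers)  # every term lies in (0, 1]
    return math.log10(scaled) + m / math.log2(10)  # change of base: log10(2**m) = m / log2(10)


print(log10_sum_pow2([3, 3]))        # log10(16), about 1.204
print(log10_sum_pow2([2000, 1999]))  # about 602.236; 2.0 ** 2000 alone raises OverflowError
```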