🐛 Contracting in some orders on a circuit with a ring structure may give incorrect multiplication counts.
Yonv1943 opened this issue · comments
收缩 sycamore 以及 tensor grid ,tensor ring 这种有环状结构的电路,会有bug,导致乘法次数计算错误
(刚好我们测试的 tensor train,tensor tree 没有环状结构)
需要有环状结构,且按某个顺序收缩张量节点,才会触发
下面的代码,在一个小规模的 sycamore 电路 NodesSycamoreN12M14 上得到,然后逐行检查发现了这个bug
num_nodes 51
num_edges 99
ban_edges 0
先粗略记录一下。
这是print代码
# Pre-fix snippet used to expose the bug: merge the two endpoint rows of an edge and print them.
# NOTE: both rows are summed unconditionally here, so when node_i0 and node_i1 already hold the
# same row (as in the printed output of this report) every entry is doubled.
'''calculate the multiple and avoid repeat'''
contract_dims = node_dims_arys[node_i0] + node_dims_arys[node_i1] # dims (and their origin) of the tensors adjacent to the contracted node
contract_bool = node_bool_arys[node_i0] | node_bool_arys[node_i1] # which original nodes the contracted node is composed of
# assert contract_dims.shape == (num_nodes, )
# assert contract_bool.shape == (num_nodes, )
# Debug output: step index and the two endpoint nodes, then each endpoint row, the summed row and the merged mask.
print(';;;', i, node_i0, node_i1)
print(node_dims_arys[node_i0].numpy().astype(int))
print(node_dims_arys[node_i1].numpy().astype(int))
print(contract_dims.numpy().astype(int))
print(contract_bool.numpy().astype(int))
这是print内容。可以看到,对已经收缩的节点竟然进行了不可能的收缩,并且产生了多余的乘法次数。
;;; 52 tensor(3) tensor(9)
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 352 0 0 0 0 96 0 0 0 128 0 0 0 0 320 0 128 0 0 0 64 0 64 64 0 0 0 0 192 0 64 64 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 352 0 0 0 0 96 0 0 0 128 0 0 0 0 320 0 128 0 0 0 64 0 64 64 0 0 0 0 192 0 64 64 0]
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 704 0 0 0 0 192 0 0 0 256 0 0 0 0 640 0 256 0 0 0 128 0 128 128 0 0 0 0 384 0 128 128 0]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 0 1 1 1 0 1 1 1 1 0 1 0 1 0 1 0 1 0 0 1 1 0 0 0 1 0 0 1]
The following code fixes this bug in Vanilla (single) mode.
The `if_diff` flag is what fixes the bug.
First, choose a reasonable data type: it is safer to use float32 instead of int/long for `node_dims_tens`.
# Stack one clone of the initial tables per environment; the dims table is cast to float32
# and the membership table to bool.
node_dims_tens = th.stack([self.node_dims_ten.clone() for _ in range(num_envs)]).type(th.float32)
node_bool_tens = th.stack([self.node_bool_ten.clone() for _ in range(num_envs)]).type(th.bool)
"""Vanilla (single)"""
# For every environment j, contract edge `edge_is[j]` and store the resulting exponent in
# `mult_pow_timess[j, i]` (`i` is defined by the surrounding code, presumably the step index).
# `if_diff` guards against contracting an edge whose two endpoints were already merged.
for j in range(num_envs):
    edge_i = edge_is[j]
    node_dims_arys = node_dims_tens[j]
    node_bool_arys = node_bool_tens[j]

    '''find two nodes of an edge_i'''
    node_i0, node_i1 = th.where(edges_ary == edge_i)[0]  # the two endpoint nodes of this edge
    # assert isinstance(node_i0.item(), int)
    # assert isinstance(node_i1.item(), int)

    '''whether node_i0 and node_i1 are different'''
    # False when node_i1 is already part of the group that node_i0 belongs to.
    if_diff = th.logical_not(node_bool_arys[node_i0, node_i1])

    '''calculate the multiple and avoid repeat'''
    # Dims (and their origin) of the tensors adjacent to the contracted node. The second row is
    # only added when the endpoints differ; otherwise both rows are identical and would be doubled.
    contract_dims = node_dims_arys[node_i0] + node_dims_arys[node_i1] * if_diff
    # Which original nodes the contracted node is composed of.
    contract_bool = node_bool_arys[node_i0] | node_bool_arys[node_i1]
    # assert contract_dims.shape == (num_nodes, )
    # assert contract_bool.shape == (num_nodes, )

    # A contracted edge only needs its multiplications counted once. Its exponent appears twice
    # in the sum above, so half of the doubly counted exponents is subtracted.
    mult_pow_time = contract_dims.sum(dim=0) - (contract_dims * contract_bool).sum(dim=0) * 0.5
    mult_pow_timess[j, i] = mult_pow_time * if_diff

    '''adjust two list: node_dims_arys, node_bool_arys'''
    # If both endpoints are the same tensor, `contract_bool & if_diff` is all False,
    # so the next line does not modify any value.
    contract_dims[contract_bool & if_diff] = 0  # contracted edges get exponent 0 (2**0) and no longer take part in later counts
    # Refresh every node selected by the mask so that all merged nodes carry the same information.
    node_dims_tens[j, contract_bool] = contract_dims.repeat(1, 1)
    node_bool_tens[j, contract_bool] = contract_bool.repeat(1, 1)
Vectorized version
"""Vectorized"""
# Vectorized over environments: contract edge `edge_is[k]` in every environment k at once and
# store the resulting exponents in column `i` of `mult_pow_timess`.
'''find two nodes of an edge_i'''
vec_edges_ary: TEN = edges_ary[None, :, :]
vec_edges_is: TEN = edge_is[:, None, None]
# Index 1 of the match positions is the node axis of `edges_ary`; the reshape assumes exactly
# two matches (the two endpoints) per environment.
res = th.where(vec_edges_ary == vec_edges_is)[1]
res = res.reshape((num_envs, 2))
node_i0s, node_i1s = res[:, 0], res[:, 1]
# assert node_i0s.shape == (num_envs, )
# assert node_i1s.shape == (num_envs, )

'''whether node_i0 and node_i1 are different'''
# False where node_i1 is already part of the group that node_i0 belongs to.
if_diffs = th.logical_not(node_bool_tens[vec_env_is, node_i0s, node_i1s])

'''calculate the multiple and avoid repeat'''
# The second endpoint row is only added where the endpoints differ; otherwise both rows are
# identical and would be doubled.
contract_dimss = (
    node_dims_tens[vec_env_is, node_i0s]
    + node_dims_tens[vec_env_is, node_i1s] * if_diffs.unsqueeze(1)
)
contract_bools = node_bool_tens[vec_env_is, node_i0s] | node_bool_tens[vec_env_is, node_i1s]
# assert contract_dimss.shape == (num_envs, num_nodes)
# assert contract_bools.shape == (num_envs, num_nodes)

# Exponents of contracted edges appear twice in the sum, so half of them is subtracted.
mult_pow_times = contract_dimss.sum(dim=1) - (contract_dimss * contract_bools).sum(dim=1) * 0.5
# assert mult_pow_times.shape == (num_envs, )
mult_pow_timess[:, i] = mult_pow_times * if_diffs

'''adjust two list: node_dims_arys, node_bool_arys'''
for j in range(num_envs):  # refresh every merged node with the same information, selected by the bool mask
    contract_dims = contract_dimss[j]
    contract_bool = contract_bools[j]
    # Contracted edges get exponent 0 (2**0) and no longer take part in later counts;
    # the mask is all False when the two endpoints were already the same tensor.
    contract_dims[contract_bool & if_diffs[j]] = 0
    node_dims_tens[j, contract_bool] = contract_dims.repeat(1, 1)
    node_bool_tens[j, contract_bool] = contract_bool.repeat(1, 1)
修改后,我还使用换底公式(Change of Base Formula)优化了修复这个 bug 之后带来的溢出问题:
`max_tmp_power / th.log2(th.tensor((10, ), device=device))`
# Even float64 occasionally overflows when computing this multiplication count, so first divide
# by 2**max_tmp_power, take log10, and then add the removed shift back.
max_tmp_power = mult_pow_timess.max(dim=1)[0] - 960 # automatically set `max - 960`; 960 < the limit 1024 (presumably the float64 exponent limit)
# After the shift the largest exponent in each row is 960.
multiple_times = (2 ** (mult_pow_timess - max_tmp_power.unsqueeze(1))).sum(dim=1)
# Restore the shift in log10 units: log10(2**p) == p / log2(10).
multiple_times = multiple_times.log10() + max_tmp_power / th.log2(th.tensor((10,), device=device))
# max_tmp_power / th.log2(th.tensor((10, ), device=device)) # Change of Base Formula