:pencil: convert nodes_id to edges_id and convert back
Yonv1943 opened this issue · comments
仿真环境需要一个功能:
把储存了节点收缩顺序的list,从 记录两个节点收缩顺序,到记录这两个节点对应的边的收缩顺序。
见代码TNCO_env.py
中的:
先创建仿真环境这个类,选择想要转换的电路
def unit_test_convert_node2s_to_edge_sorts():
gpu_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0
device = th.device(f'cuda:{gpu_id}' if th.cuda.is_available() and gpu_id >= 0 else 'cpu')
nodes_list, ban_edges = NodesSycamoreN12M14, 0
# nodes_list, ban_edges = NodesSycamoreN14M14, 0
# nodes_list, ban_edges = NodesSycamoreN53M12, 0
# nodes_list, ban_edges = get_nodes_list_of_tensor_train(len_list=8), 8
# nodes_list, ban_edges = get_nodes_list_of_tensor_train(len_list=100), 100
# nodes_list, ban_edges = get_nodes_list_of_tensor_train(len_list=2000), 2000
# from TNCO_env import get_nodes_list_of_tensor_tree
# nodes_list, ban_edges = get_nodes_list_of_tensor_tree(depth=3), 2 ** (3 - 1)
env = TensorNetworkEnv(nodes_list=nodes_list, ban_edges=ban_edges, device=device)
print(f"\nnum_nodes {env.num_nodes:9}"
f"\nnum_edges {env.num_edges:9}"
f"\nban_edges {env.ban_edges:9}")
下面演示了把 edge_ary
转化成 node2s
转化回 edge_ary
的过程,调用了两个函数:
- edge_ary → edge_sort → node2s
node2s = env.convert_edge_sort_to_node2s(edge_sort=edge_ary.argsort(dim=0))
- node2s → edge_sort
edge_sort = env.convert_node2s_to_edge_sort(node2s=node2s).to(device)
num_envs = 6
# th.save(edge_arys, 'temp.pth')
# edge_arys = th.load('temp.pth', map_location=device)
edge_arys = th.rand((num_envs, env.num_edges - env.ban_edges), device=device)
edge_ary = edge_arys[0]
print(edge_ary.argsort().shape)
print(edge_ary.argsort())
node2s = env.convert_edge_sort_to_node2s(edge_sort=edge_ary.argsort(dim=0))
edge_sort = env.convert_node2s_to_edge_sort(node2s=node2s).to(device)
print(edge_sort.shape)
print(edge_sort)
print(edge_sort - edge_ary.argsort())
edge_sorts = edge_sort.unsqueeze(0)
multiple_times = env.get_log10_multiple_times(edge_sorts=edge_sorts)
print(f"multiple_times(log10) {multiple_times.numpy()}")
输出是:(在这个电路下,nodes_list, ban_edges = NodesSycamoreN12M14, 0
)
num_nodes 51
num_edges 99
ban_edges 0
torch.Size([99])
tensor([30, 36, 49, 35, 55, 65, 0, 28, 61, 52, 45, 69, 10, 21, 83, 18, 56, 9,
14, 70, 39, 19, 74, 43, 68, 75, 60, 81, 29, 47, 94, 24, 58, 77, 64, 15,
13, 72, 87, 32, 71, 51, 85, 6, 44, 34, 96, 40, 38, 97, 46, 53, 82, 84,
22, 90, 25, 23, 33, 92, 1, 62, 42, 91, 67, 93, 26, 98, 79, 12, 16, 27,
78, 95, 8, 11, 80, 20, 4, 57, 73, 54, 2, 7, 66, 3, 5, 88, 37, 59,
17, 48, 50, 41, 86, 89, 76, 63, 31])
torch.Size([99])
tensor([30, 36, 49, 35, 55, 65, 0, 28, 61, 52, 45, 69, 10, 21, 83, 18, 56, 9,
14, 70, 39, 19, 74, 43, 68, 75, 60, 81, 29, 47, 94, 24, 58, 77, 64, 15,
13, 72, 87, 32, 71, 51, 85, 6, 44, 34, 96, 40, 38, 97, 46, 53, 82, 84,
22, 90, 25, 23, 33, 92, 1, 62, 42, 91, 67, 93, 26, 98, 79, 12, 16, 27,
78, 95, 8, 11, 80, 20, 4, 57, 73, 54, 2, 7, 66, 3, 5, 88, 37, 59,
17, 48, 50, 41, 86, 89, 76, 63, 31], dtype=torch.int32)
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0])
multiple_times(log10) [12.06995569]
上面添加了一个更新,经过校验,张潘用的节点收缩顺序,与我们不同。
按照要收缩的边两端的点的序号 47以及70 ,我们查找 节点连接的无向图,发现这两个点不相连。于是推测我们用的 点的编号不一致。(可我们的点的编号用的就是 sycamore 他们官方提供的编号 方案,而不是我们自己的编号)
经过后面的检查,我发现是因为 他们给收缩后的节点一个新的编号,而我们给收缩后的节点的编号沿用了收缩前的任意一个节点,因此在查找节点相连情况时 出错。(2023-04-14 16:03:48)
请注意:
- 我们的仿真环境兼容 两个节点收缩后编号为某个节点 的情况。例如, 节点A 与节点B 与节点C收缩后,无论它编号为 A,B或者C,我们的仿真环境都可以处理。
- 我们的仿真环境,已经使用上面的 nodes_id 与 edges_id 相互转换检验过了。
- 可能是他们把 sycamore 里的某些 量子比特挖掉了。
# node2s = Node2sSycamoreN53N20Zhang1Pan1V1 # ERROR because of their un standard node_id (from paper)
num_nodes 381
num_edges 754
ban_edges 0
47 tensor([107, 110, 66, 65], dtype=torch.int32) ;;; 70 tensor([246, 152, 112, 111], dtype=torch.int32)
47 tensor([68, 69, 26, 27], dtype=torch.int32) ;;; 70 tensor([137, 90, 28, 48], dtype=torch.int32)
所以会有以下 IndexError
line 592, in convert_node2s_to_edge_sort
edge_i = [edge_i for edge_i in edge_is if edge_i != -1][0]
IndexError: list index out of range
他们论文的张量收缩方式,和他们代码的不一样,下面尝试了他们代码的收缩方式,依然有相同的错误:
# node2s = Node2sSycamoreN53N20Zhang1Pan1V2 # ERROR because of their un standard node_id (from code)
num_nodes 381
num_edges 754
ban_edges 0
47 tensor([107, 110, 66, 65], dtype=torch.int32) ;;; 7 tensor([26, 32, 2, 7], dtype=torch.int32)
47 tensor([68, 69, 26, 27], dtype=torch.int32) ;;; 7 tensor([26, 30, 2, 15], dtype=torch.int32)
下面是第三个收缩方式,依然有错:
32 tensor([120, 76, 36, -1], dtype=torch.int32) ;;; 74 tensor([376, 160, 120, 119], dtype=torch.int32)
32 tensor([74, 52, 9, -1], dtype=torch.int32) ;;; 74 tensor([202, 94, 32, 52], dtype=torch.int32)
52 tensor([119, 122, 76, 75], dtype=torch.int32) ;;; 32 tensor([120, 76, 36, -1], dtype=torch.int32)
52 tensor([74, 75, 32, 33], dtype=torch.int32) ;;; 32 tensor([74, 52, 9, -1], dtype=torch.int32)
75 tensor([159, 162, 122, 121], dtype=torch.int32) ;;; 94 tensor([208, 199, 160, 159], dtype=torch.int32)
75 tensor([94, 95, 52, 53], dtype=torch.int32) ;;; 94 tensor([118, 114, 74, 75], dtype=torch.int32)
52 tensor([119, 122, 76, 75], dtype=torch.int32) ;;; 75 tensor([159, 162, 122, 121], dtype=torch.int32)
52 tensor([74, 75, 32, 33], dtype=torch.int32) ;;; 75 tensor([94, 95, 52, 53], dtype=torch.int32)
13 tensor([38, 43, 5, 50], dtype=torch.int32) ;;; 33 tensor([75, 78, 38, 37], dtype=torch.int32)
13 tensor([33, 36, 5, 39], dtype=torch.int32) ;;; 33 tensor([52, 53, 13, 10], dtype=torch.int32)
53 tensor([121, 124, 78, 77], dtype=torch.int32) ;;; 13 tensor([38, 43, 5, 50], dtype=torch.int32)
53 tensor([75, 76, 33, 34], dtype=torch.int32) ;;; 13 tensor([33, 36, 5, 39], dtype=torch.int32)
Traceback (most recent call last):
line 669, in convert_node2s_to_edge_sort
edge_i = [edge_i for edge_i in edge_is if edge_i != -1][0]
IndexError: list index out of range
I offer another version to try out
The reason for the error may be the use of open quantum bits in their method, which requires a reduction in the number of tensor from 381 to 345, possibly leading to a numbering inconsistency problem between us and them.
这里写一下,如何校验 节点收缩顺序。
下面的函数,根据 节点收缩list node2s: list
提供的收缩顺序,以及我们仿真环境自动生成的 收缩的边的编号表,得到 边的编号的收缩顺序。
def convert_node2s_to_edge_sort(self, node2s: list) -> TEN:
edges_ary: TEN = self.edges_ary.cpu()
nodes_ary: TEN = self.nodes_ary.cpu()
edge_sort = []
import numpy as np
for node_i0, node_i1 in node2s:
print(f"{node_i0:4} {str(edges_ary[node_i0].numpy()):17} "
f"{node_i1:4} {str(edges_ary[node_i1].numpy()):17} |"
f"{node_i0:4} {str(nodes_ary[node_i0].numpy()):17} "
f"{node_i1:4} {str(nodes_ary[node_i1].numpy()):17}")
edge_is = np.intersect1d(edges_ary[node_i0], edges_ary[node_i1])
edge_i = [edge_i for edge_i in edge_is if edge_i != -1][0]
edge_sort.append(edge_i)
edge_sort = th.tensor(edge_sort)
return edge_sort
输入以下
nodes_list, ban_edges = NodesSycamoreN53M12, 0
Node2sSycamoreN53N20Xu3Wei3 = [
(32, 9), (190, 189), (3, 0), ...]
node2s = Node2sSycamoreN53N20Xu3Wei3
edge_sort = env.convert_node2s_to_edge_sort(node2s=node2s).to(device)
出现报错,以及log信息:
num_nodes 381
num_edges 754
ban_edges 0
32 [120 76 36 -1] 9 [29 36 44 -1] | 32 [74 52 9 -1] 9 [28 32 36 -1]
190 [436 396 352 351] 189 [350 349 568 520] | 190 [232 212 171 170] 189 [170 147 298 274]
...
edge_i = [edge_i for edge_i in edge_is if edge_i != -1][0]
IndexError: list index out of range
解读第一个张量收缩过程:
32 [120 76 36 -1] 9 [29 36 44 -1] | 32 [74 52 9 -1] 9 [28 32 36 -1]
node2s
的第一组数据是(32, 9)
希望收缩编号为32 和 9 的两个node- 程序在
edges_ary
这个node连接表上查找,可以看到,这两个node 被 编号为36 的edge连起来了:- 编号为32 的node 连接了 这些编号为
[120 76 36 -1]
的edge - 编号为9 的node 连接了 这些编号为
[29 36 44 -1]
的edge
- 编号为32 的node 连接了 这些编号为
- 程序j继续在
nodes_ary
这个edge 连接表上查找,可以看到,这两个node 的确被连起来,因为他们相互记录了对方的 node 编号:- 编号为32 的node 连接了 这些编号为
[74 52 9 -1]
的node ,我们能在里面找到 编号9 的node - 编号为9 的node 连接了 这些编号为
[28 32 36 -1]
的node ,我们能在里面找到 编号32 的node
- 编号为32 的node 连接了 这些编号为
继续解读下一个张量收缩过程,发现出错了
这不是因为用户提供的 node_arys
和 node2s
不匹配,而是因为我没有更新收缩后的张量收缩表,
190 [436 396 352 351] 189 [350 349 568 520] | 190 [232 212 171 170] 189 [170 147 298 274]
node2s
的第二组数据是(190, 189)
希望收缩编号为190 和 189 的两个node- 程序在
edges_ary
这个node连接表上查找,可以看到,这两个node 连接的nodes 取交集发现是空集,因此没有相连,无法收缩:- 编号为190 的node 连接了 这些编号为
[436 396 352 351]
的edge - 编号为189 的node 连接了 这些编号为
[350 349 568 520]
的edge
- 编号为190 的node 连接了 这些编号为
- 程序j继续在
nodes_ary
这个edge 连接表上查找,可以看到,这两个node 没有相连,因为他们记录了的 node 编号没有对方:- 编号为190 的node 连接了 这些编号为
[232 212 171 170]
的node ,我们不能在里面找到 编号189 的node - 编号为189 的node 连接了 这些编号为
[170 147 298 274]
的node ,我们不能在里面找到 编号190 的node
- 编号为190 的node 连接了 这些编号为
问题已经解决,是我的 转换代码有问题。
- 先前是因为别人的张量收缩顺序,给收缩后的张量一个新的
node_id
,所以出错。改成顶替掉收缩前的任意一个node_id
就好 - 后来出错,是因为张量
node_i0
和node_i1
收缩后,应该给哪个张量的node_id
呢? 我的代码没有做这个适配,因此出错 - 最后改成:只要收缩后的
node_id
顶替掉原本收缩前的任意一个node_id
就好
Node2sSycamoreN53N20Zhang1Pan1V4 = [
(32, 9), (190, 189), (3, 0), (6, 1), (13, 5), (169, 127), (127, 146), (192, 152), (193, 153), (196, 113),
(197, 175), (201, 74), (205, 160), (208, 164), (210, 167), (7, 2), (12, 4), (16, 10), (17, 11), (18, 14), (23, 8),
(35, 15), (42, 20), (20, 62), (43, 22), (45, 24), (24, 66), (46, 25), (47, 26), (48, 28), (28, 70), (49, 29),
(50, 30), (51, 27), (52, 33), (53, 34), (56, 36), (36, 78), (57, 37), (58, 38), (59, 39), (39, 81), (60, 40),
(61, 41), (41, 83), (41, 103), (63, 21), (64, 44), (76, 54), (77, 55), (55, 97), (86, 65), (87, 67), (88, 68),
(89, 69), (90, 71), (91, 72), (92, 73), (94, 75), (98, 79), (99, 80), (101, 82), (104, 19), (105, 84), (107, 85),
(115, 95), (116, 96), (117, 93), (121, 100), (124, 102), (130, 108), (131, 111), (132, 106), (106, 149), (133, 114),
(136, 109), (137, 118), (138, 110), (139, 120), (140, 112), (141, 122), (142, 123), (144, 119), (145, 126),
(126, 168), (148, 128), (128, 170), (150, 129), (129, 171), (151, 135), (157, 134), (158, 143), (172, 147),
(173, 154), (174, 155), (178, 159), (180, 161), (181, 162), (182, 163), (183, 156), (184, 165), (185, 166),
(194, 176), (195, 177), (199, 198), (203, 202), (206, 186), (207, 187), (209, 188), (109, 152), (118, 160),
(118, 179), (110, 153), (122, 164), (167, 125), (21, 20), (44, 22), (54, 34), (67, 24), (71, 28), (79, 36),
(82, 39), (135, 108), (155, 106), (166, 143), (30, 31), (147, 191), (159, 200), (202, 204), (159, 113), (110, 161),
(122, 156), (125, 119), (188, 127), (20, 84), (22, 65), (24, 25), (28, 29), (36, 37), (39, 40), (106, 112),
(143, 102), (147, 189), (118, 175), (113, 93), (122, 114), (127, 165), (202, 74), (93, 109), (22, 8), (24, 68),
(28, 72), (36, 80), (106, 163), (147, 128), (34, 11), (30, 2), (127, 123), (74, 9), (20, 0), (119, 41), (93, 69),
(8, 85), (24, 26), (36, 38), (106, 100), (128, 19), (102, 96), (123, 134), (9, 33), (41, 126), (114, 118),
(19, 108), (69, 27), (2, 4), (123, 95), (9, 75), (8, 1), (19, 129), (27, 73), (2, 15), (100, 55), (41, 10),
(19, 154), (9, 5), (95, 14), (19, 111), (55, 177), (19, 120), (5, 162), (14, 176), (5, 198), (5, 187), (5, 186),
(5, 114), (5, 110), (5, 28), (5, 10), (5, 14), (5, 36), (5, 39), (5, 11), (5, 96), (5, 19), (5, 55), (5, 27),
(5, 2), (2, 24), (2, 1), (1, 0)
]
对于上面的输入,它有以下输出 multiple_times(log10) [17.49347767]
,验证完成
num_nodes 211
num_edges 414
ban_edges 0
torch.Size([210])
tensor([ 36, 354, 0, 1, 5, 309, 310, 358, 361, 369, 372, 385, 397, 406,
413, 2, 4, 8, 9, 10, 19, 42, 56, 96, 57, 62, 104, 64,
66, 68, 112, 70, 72, 73, 75, 77, 84, 128, 86, 88, 90, 133,
92, 94, 137, 177, 97, 99, 123, 125, 165, 143, 145, 147, 149, 151,
153, 155, 159, 167, 169, 173, 180, 182, 186, 201, 203, 206, 213, 219,
231, 233, 236, 269, 237, 244, 245, 248, 249, 252, 253, 255, 260, 261,
307, 268, 312, 272, 314, 273, 286, 287, 317, 318, 320, 328, 332, 334,
336, 339, 340, 342, 364, 367, 379, 391, 399, 403, 409, 275, 291, 330,
277, 299, 305, 98, 100, 80, 106, 152, 168, 136, 274, 321, 304, 71,
355, 381, 395, 370, 333, 338, 306, 348, 139, 144, 105, 113, 129, 89,
281, 257, 353, 373, 289, 284, 349, 388, 329, 59, 148, 154, 170, 337,
350, 39, 32, 264, 120, 13, 221, 189, 141, 107, 131, 297, 266, 258,
285, 76, 411, 322, 232, 109, 35, 256, 160, 58, 357, 117, 41, 214,
54, 356, 38, 225, 280, 366, 250, 389, 324, 378, 347, 404, 384, 345,
200, 37, 162, 43, 49, 161, 224, 288, 352, 195, 33, 192, 34, 3])
Not Standard edge_sorts
multiple_times(log10) [17.49347767]
经过检查,我们发现了第二个错误:
代码路径是 ...\Python39\Lib\site-packages\opt_einsum\contract.py
for cnum, contract_inds in enumerate(path):
# Make sure we remove inds from right to left
contract_inds = tuple(sorted(list(contract_inds), reverse=True))
contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
out_inds, input_sets, idx_removed, idx_contract = contract_tuple
# Compute cost, scale, and size
cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), size_dict)
cost_list.append(cost)
print(cost) # -----------------------------------------------------------------> 打印出每一次张量收缩产生的乘法次数
scale_list.append(len(idx_contract))
size_list.append(helpers.compute_size_by_dict(out_inds, size_dict))
tmp_inputs = [input_list.pop(x) for x in contract_inds]
tmp_shapes = [input_shps.pop(x) for x in contract_inds]
如上面代码所示,我看了他们 opt_einsum 的源代码,把他们每一次收缩的乘法次数print出来,整理到 excel 表格里
蓝色的是他们的正确结果,红色的是我们的,每一次的计算都相差 2**1
我们初始化的时候,给每个节点它自己多送了一个 2 ** 1 标记,然后在计算重复的乘法时,我们要减去 被 bool_ary 标记出来的 节点,然后多减了它自己。
因此,一行代码就能修复
'''calculate the multiple and avoid repeat'''
ct_dimss = dims_tens[env_is, node_i0s] + dims_tens[env_is, node_i1s] * if_diffs.unsqueeze(1)
ct_bools = bool_tens[env_is, node_i0s] | bool_tens[env_is, node_i1s]
# assert ct_dimss.shape == (num_envs, num_nodes)
# assert ct_bools.shape == (num_envs, num_nodes)
# 初始化的时候,给每个节点它自己多送了一个 2 ** 1 标记,排除重复的乘法时,会多减了它自己,下面的代码把它加回去
pow_times = ct_dimss.sum(dim=1) - (ct_dimss * ct_bools).sum(dim=1) * 0.5 + 1 # --------> 多了一个加一,修复了bug
pow_timess[:, i] = pow_times * if_diffs
跑出来的结果是:
node2s = Node2sSycamoreN53N20Test1
# log10(multiple_times) =
他们的结果 25.6106868931126
我们的结果 25.6106813 # power_max - 960
我们的结果 25.61068416 # power_max - 512
用Python自带的 int 大数计算得到的乘法次数
真实的结果 40802511241875888470868352
真实的结果 25.6106868931126
我们的结果当前能精确到有效数字7位,我可以继续改进,这就和这个 issue 无关了。
精度和下面的代码有关,下面的代码为了用 float64 计算更大的数值,减去了 (power_max - 960)
,但是损失了精度。
# 计算这个乘法个数时,即便用 float64 也偶尔会过拟合,所以先除以 2**temp_power ,求log10 后再恢复它
adj_pow_times = pow_timess.max(dim=1)[0] - 960 # automatically set `max - 960`, 960 < the limit 1024,
multiple_times = (2 ** (pow_timess - adj_pow_times.unsqueeze(1))).sum(dim=1)
multiple_times = multiple_times.log10() + adj_pow_times / th.log2(th.tensor((10,), device=device))
# adj_pow_times / th.log2(th.tensor((10, ), device=device)) # Change of Base Formula
以下提供一种缓慢但是完全不损失精度的计算方法
结果是:
multiple_times(log10) [25.61068689]
diff 0.000e+00
代码是:
# 缓慢但是完全不损失精度的计算方法
multiple_times = []
pow_timess = pow_timess.cpu().numpy()
for env_id in range(num_envs):
multiple_time = 0
for pow_time in pow_timess[env_id, :]:
multiple_time = multiple_time + 2 ** pow_time
multiple_time = math_log10(multiple_time)
multiple_times.append(multiple_time)
multiple_times = th.tensor(multiple_times, dtype=th.float64)
重要补充:
我认为 opt_einsum 的计算出错了,而不是我们出错。下面是之前将我们 TensorNetworkEnv 与 opt_sinsum 的结果进行的对比:
经过检查,我们发现了第二个错误:
代码路径是
...\Python39\Lib\site-packages\opt_einsum\contract.py
for cnum, contract_inds in enumerate(path): # Make sure we remove inds from right to left contract_inds = tuple(sorted(list(contract_inds), reverse=True)) contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set) out_inds, input_sets, idx_removed, idx_contract = contract_tuple # Compute cost, scale, and size cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), size_dict) cost_list.append(cost) print(cost) # -----------------------------------------------------------------> 打印出每一次张量收缩产生的乘法次数 scale_list.append(len(idx_contract)) size_list.append(helpers.compute_size_by_dict(out_inds, size_dict)) tmp_inputs = [input_list.pop(x) for x in contract_inds] tmp_shapes = [input_shps.pop(x) for x in contract_inds]
如上面代码所示,我看了他们 opt_einsum 的源代码,把他们每一次收缩的乘法次数print出来,整理到 excel 表格里
蓝色的是他们的正确结果,红色的是我们的,每一次的计算都相差 2**1
下面讨论的 +1,其实是最终的乘法次数 + log10(2**1)
- @spicywei 自己手动计算了 TensorTrain (6个节点,不包含虚拟节点)的结果,发现不需要 +1
- 我们查看了 sycamore电路的前几个收缩产生的乘法次数,发现不需要 +1,而是opt_einsum 出错了。
检查过程:
- 在 Node2sSycamoreN53N20 的电路里,根据张量收缩顺序
Node2sSycamoreN53N20Test2 = [(32, 9), (360, 359), ...]
,我们先收缩 编号 32 和 编号 9 的这两个节点 - 编号为32 的node 连接了 这些编号为
[74 52 9 -1]
的node ,我们能在里面找到 编号9 的node,编号为9 的node 连接了 这些编号为[28 32 36 -1]
的node ,我们能在里面找到 编号32 的node - 计算的乘法次数如下:编号32外接的边是 3条,编号9外接的边是3条,这两个节点中间有一条要收缩,乘法次数为
2 ** (3 + 3 - 1) == 2 ** 5
- 如下面表格截图所示, opt_einsum 算出来是
2 ** 6
, 而我们算出来是2 ** 5
,我们是对的,因此不需要给乘法次数加上log10(2**1)
如下方截图所示:收缩 编号 32 和 编号 9 的这两个节点 产生的乘法次数被记录在第一行,其中左列是 opt_einsum 的结果,右列是 我们TensorNetworkEnv 的结果。
我在本地对上述过程进行了二次核验,确实存在opt_einsum计算错误的问题