📝 update GraphMaxCutEnv to VecEnv
Yonv1943 opened this issue · comments
YonV1943 曾伊言 commented
import torch as th
import numpy as np
from torch import Tensor
class GraphMaxCutEnv:
    """Vectorised Max-Cut environment built from an edge-list text file.

    File format (as parsed below): the first non-empty line holds
    ``num_nodes num_edges``; every following non-empty line holds
    ``node0 node1 weight`` with 1-based node ids. Each edge is stored once,
    keyed by its first node (``node0``).

    A candidate solution ``p0s`` has shape ``(num_envs, num_nodes)``; entry
    ``p`` is the probability that the node lies in the first set
    (``1 - p``: it lies in the other set). The objective of one environment
    is the expected weight of the cut::

        sum over edges (n0, n1, w) of  w * (p[n0] * (1 - p[n1]) + p[n1] * (1 - p[n0]))

    Terminology:
        n0: index of node0
        n1: index of node1
        dts: weights ("distances") of the edges leaving node0
        dt: probability that the edge (node0, node1) is cut
        p0: probability that node0 is in the set, (1 - p0): node0 is in the other set
        p1: probability that node1 is in the set, (1 - p1): node1 is in the other set
    """

    def __init__(self, num_envs=8, device=th.device('cpu'), txt_path="./graph_set_G14.txt"):
        """Load the graph and build per-node adjacency tensors.

        Args:
            num_envs: number of parallel environments (rows of ``p0s``).
            device: torch device on which the adjacency tensors are created.
            txt_path: path of the edge-list file; defaults to the file this
                environment originally hard-coded.
        """
        with open(txt_path, 'r') as file:
            # Skip blank lines (e.g. a trailing empty line at end of file):
            # they would parse to [] and break the indexing below.
            rows = [[int(token) for token in line.split()] for line in file if line.strip()]
        num_nodes, num_edges = rows[0]
        # Convert 1-based node ids to 0-based indices.
        edge_to_n0_n1_dist = [(row[0] - 1, row[1] - 1, row[2]) for row in rows[1:]]

        n0_to_n1s = [[] for _ in range(num_nodes)]  # node0 index -> indices of its node1 neighbours
        n0_to_dts = [[] for _ in range(num_nodes)]  # node0 index -> weights of those edges
        for n0, n1, dist in edge_to_n0_n1_dist:
            n0_to_n1s[n0].append(n1)
            n0_to_dts[n0].append(dist)
        n0_to_n1s = [th.tensor(node1s, dtype=th.long, device=device) for node1s in n0_to_n1s]
        n0_to_dts = [th.tensor(dists, dtype=th.long, device=device) for dists in n0_to_dts]

        # The header's edge count must match the number of edge lines read.
        num_read_edges = sum(n1s.shape[0] for n1s in n0_to_n1s)
        assert num_edges == num_read_edges

        self.num_envs = num_envs
        self.num_nodes = num_nodes
        self.num_edges = num_read_edges
        self.n0_to_n1s = n0_to_n1s
        self.n0_to_dts = n0_to_dts
        self.device = device

        # For faster computation, drop nodes without outgoing edges:
        # they contribute nothing to the objective.
        v2_ids = [i for i, n1s in enumerate(n0_to_n1s) if n1s.shape[0] > 0]
        self.v2_ids = v2_ids
        self.v2_n0_to_n1s = [n0_to_n1s[idx] for idx in v2_ids]
        self.v2_n0_to_dts = [n0_to_dts[idx] for idx in v2_ids]
        self.v2_num_nodes = len(v2_ids)

    def get_objective(self, p0s):
        """Reference (non-vectorised) objective: expected cut weight per environment.

        Loops over environments and nodes in Python; serves as a correctness
        baseline for the vectorised versions.

        Args:
            p0s: tensor of shape ``(self.num_envs, self.num_nodes)``.
        Returns:
            tensor of shape ``(self.num_envs,)``.
        """
        assert p0s.shape == (self.num_envs, self.num_nodes)
        sum_dts = []
        for env_i in range(self.num_envs):
            p0 = p0s[env_i]
            n0_to_p1 = [p0[n1s] for n1s in self.n0_to_n1s]
            sum_dt = []
            for _p0, _p1, _w in zip(p0, n0_to_p1, self.n0_to_dts):
                # dt = _p0 * (1-_p1) + _p1 * (1-_p0)  # equivalent to the line below
                dt = _p0 + _p1 - 2 * _p0 * _p1
                sum_dt.append((dt * _w).sum(dim=0))  # weight each edge's cut probability
            sum_dt = th.stack(sum_dt).sum(dim=-1)
            sum_dts.append(sum_dt)
        return th.hstack(sum_dts)

    def get_objectives_v1(self, p0s):  # version 1
        """Objective vectorised over environments; loops over every node.

        Args:
            p0s: tensor of shape ``(batch, self.num_nodes)``; ``batch`` is
                normally ``self.num_envs``.
        Returns:
            tensor of shape ``(batch,)``.
        """
        num_envs = p0s.shape[0]
        # Accumulate in at least float32; float64 input keeps float64 precision.
        dtype = th.promote_types(p0s.dtype, th.float32)
        sum_dts = th.zeros((num_envs, self.num_nodes), dtype=dtype, device=p0s.device)
        for node_i, (n1s, dts) in enumerate(zip(self.n0_to_n1s, self.n0_to_dts)):
            _p0 = p0s[:, node_i].unsqueeze(1)  # (num_envs, 1)
            _p1 = p0s[:, n1s]  # (num_envs, num_n1); num_n1 may be 0
            dt = _p0 + _p1 - 2 * _p0 * _p1
            sum_dts[:, node_i] = (dt * dts).sum(dim=-1)
        return sum_dts.sum(dim=-1)

    def get_objectives(self, p0s):  # version 2
        """Vectorised objective that skips nodes without outgoing edges.

        Equal to ``get_objectives_v1`` up to float rounding (fewer terms are
        accumulated, so the summation order differs).

        Args:
            p0s: tensor of shape ``(batch, self.num_nodes)``; ``batch`` is
                normally ``self.num_envs``.
        Returns:
            tensor of shape ``(batch,)``.
        """
        num_envs = p0s.shape[0]
        dtype = th.promote_types(p0s.dtype, th.float32)
        v2_p0s = p0s[:, self.v2_ids]
        sum_dts = th.zeros((num_envs, self.v2_num_nodes), dtype=dtype, device=p0s.device)
        for node_i, (n1s, dts) in enumerate(zip(self.v2_n0_to_n1s, self.v2_n0_to_dts)):
            _p0 = v2_p0s[:, node_i].unsqueeze(1)  # (num_envs, 1)
            _p1 = p0s[:, n1s]  # (num_envs, num_n1)
            dt = _p0 + _p1 - 2 * _p0 * _p1
            sum_dts[:, node_i] = (dt * dts).sum(dim=-1)
        return sum_dts.sum(dim=-1)

    def get_rand_p0s(self):
        """Return uniform random probabilities of shape ``(num_envs, num_nodes)`` on ``self.device``."""
        device = self.device
        return th.rand((self.num_envs, self.num_nodes), dtype=th.float32, device=device)
def check_env():
    """Print the reference objective and both vectorised objectives on the same random input.

    The three printed rows should agree (up to float rounding), which
    validates the vectorised implementations against the loop-based one.
    """
    th.manual_seed(0)  # reproducible random probabilities
    env = GraphMaxCutEnv(num_envs=6)
    rand_probs = env.get_rand_p0s()
    objective_fns = (env.get_objective, env.get_objectives_v1, env.get_objectives)
    for objective_fn in objective_fns:
        print(objective_fn(rand_probs))
# Runs the sanity check whenever this file is executed or imported.
# NOTE(review): consider guarding with `if __name__ == "__main__":` so that
# importing GraphMaxCutEnv does not try to read the graph file.
check_env()
YonV1943 曾伊言 commented
检查结果如下:
- 第一行:不使用矢量并行的环境(get_objective),在 for 循环里逐个求出的 6 个结果(num_envs=6)
- 第二行,矢量并行的环境,version1,求出的结果无误
- 第三行:矢量并行的环境,version2(get_objectives),跳过了邻居节点数量为 0 的节点的计算,结果与前两行相差在 1e-3 以内(应是 float32 求和顺序不同带来的舍入误差),可以接受
tensor([2345.0999, 2354.1797, 2337.8169, 2338.0452, 2332.6572, 2356.8047])
tensor([2345.0999, 2354.1797, 2337.8169, 2338.0452, 2332.6572, 2356.8047])
tensor([2345.0999, 2354.1794, 2337.8171, 2338.0452, 2332.6570, 2356.8047])