AI4Finance-Foundation / RLSolver

Solvers for NP-hard and NP-complete problems with an emphasis on high-performance GPU computing.

Home Page:https://ai4finance.org

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

📝 update GraphMaxCutEnv to VecEnv

Yonv1943 opened this issue · comments

import torch as th
import numpy as np
from torch import Tensor


class GraphMaxCutEnv:
    """Evaluate the relaxed max-cut objective for a batch of node assignments.

    The graph is read from a text file whose first non-blank line is
    ``num_nodes num_edges`` and whose remaining lines are
    ``node0 node1 dist`` with 1-based node ids. Every edge is stored once,
    under its first endpoint ``node0``.

    Notation used throughout:
        n0: index of node0 (first endpoint of an edge)
        n1: index of node1 (second endpoint of an edge)
        p0: probability that node0 is in the set, (1 - p0): it is in the other set
        p1: probability that node1 is in the set, (1 - p1): it is in the other set

    The contribution of one edge to the objective is
    ``p0 * (1 - p1) + p1 * (1 - p0)``, written as ``p0 + p1 - 2 * p0 * p1``.

    NOTE(review): the third column of the file (``dist``) is parsed but does
    not enter the objective, i.e. every edge counts as 1. That matches files
    whose distances are all 1; confirm before loading weighted graphs.
    """

    def __init__(self, num_envs=8, device=th.device('cpu'), txt_path="./graph_set_G14.txt"):
        """Load the graph and precompute the index tensors used by the objectives.

        Args:
            num_envs: number of assignments evaluated in parallel (batch size).
            device: torch device on which the index tensors are created.
            txt_path: path of the graph file; defaults to the G14 graph file.

        Raises:
            AssertionError: if the number of edge lines differs from the
                edge count announced in the header line.
        """
        with open(txt_path, 'r') as file:
            # Blank lines (e.g. a trailing newline) carry no data; skip them.
            rows = [[int(token) for token in line.split()] for line in file if line.strip()]

        num_nodes, num_edges = rows[0]
        # Node ids in the file are 1-based; shift them to 0-based indices.
        edge_to_n0_n1_dist = [(row[0] - 1, row[1] - 1, row[2]) for row in rows[1:]]
        assert num_edges == len(edge_to_n0_n1_dist), (
            f"header announces {num_edges} edges but {len(edge_to_n0_n1_dist)} were read"
        )

        # Map every node0 index to the tensor of its node1 indices.
        n0_to_n1s = [[] for _ in range(num_nodes)]
        for n0, n1, _dist in edge_to_n0_n1_dist:
            n0_to_n1s[n0].append(n1)
        n0_to_n1s = [th.tensor(node1s, dtype=th.long, device=device) for node1s in n0_to_n1s]

        self.num_envs = num_envs
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        self.n0_to_n1s = n0_to_n1s
        self.device = device

        # "v2" view: nodes with an empty node1 list are dropped so that
        # per-node loops can skip them.
        v2_ids = [i for i, n1 in enumerate(n0_to_n1s) if n1.shape[0] > 0]
        self.v2_ids = v2_ids
        self.v2_n0_to_n1s = [n0_to_n1s[idx] for idx in v2_ids]
        self.v2_num_nodes = len(v2_ids)

        # Flat endpoint index tensors (one entry per edge) so that
        # get_objectives can gather all edges with a single indexing op.
        self._edge_n0s = th.tensor([edge[0] for edge in edge_to_n0_n1_dist],
                                   dtype=th.long, device=device)
        self._edge_n1s = th.tensor([edge[1] for edge in edge_to_n0_n1_dist],
                                   dtype=th.long, device=device)

    def get_objective(self, p0s):
        """Reference implementation: loop over environments and nodes.

        Args:
            p0s: tensor of shape ``(num_envs, num_nodes)`` holding, per
                environment, the probability of each node being in the set.

        Returns:
            1-D tensor with one objective value per environment.

        Raises:
            AssertionError: if ``p0s`` does not have shape
                ``(num_envs, num_nodes)``.
        """
        assert p0s.shape == (self.num_envs, self.num_nodes)

        sum_dts = []
        for p0 in p0s:
            sum_dt = []
            for _p0, n1 in zip(p0, self.n0_to_n1s):
                _p1 = p0[n1]
                # Same as _p0 * (1 - _p1) + _p1 * (1 - _p0).
                dt = _p0 + _p1 - 2 * _p0 * _p1
                sum_dt.append(dt.sum(dim=0))
            sum_dts.append(th.stack(sum_dt).sum(dim=-1))
        return th.hstack(sum_dts)

    def get_objectives_v1(self, p0s):  # version 1
        """Batched over environments, still looping over every node.

        Args:
            p0s: tensor of shape ``(batch, num_nodes)``; the batch size is
                taken from ``p0s`` itself.

        Returns:
            float32 tensor of shape ``(batch,)`` with the objective values.
        """
        num_envs = p0s.shape[0]

        sum_dts = th.zeros((num_envs, self.num_nodes), dtype=th.float32, device=p0s.device)
        for node_i, n1 in enumerate(self.n0_to_n1s):
            _p0 = p0s[:, node_i].unsqueeze(1)
            # Column gather; an empty n1 yields a (batch, 0) tensor whose sum is 0.
            _p1 = p0s[:, n1]

            dt = _p0 + _p1 - 2 * _p0 * _p1
            sum_dts[:, node_i] = dt.sum(dim=-1)
        return sum_dts.sum(dim=-1)

    def get_objectives(self, p0s):  # version 2
        """Fully vectorised objective: one gather over all edges, no Python loop.

        The summation order differs from the per-node versions, so results
        may differ from them by float rounding error.

        Args:
            p0s: tensor of shape ``(batch, num_nodes)``; the batch size is
                taken from ``p0s`` itself.

        Returns:
            float32 tensor of shape ``(batch,)`` with the objective values.
        """
        p_n0 = p0s[:, self._edge_n0s]
        p_n1 = p0s[:, self._edge_n1s]

        dt = p_n0 + p_n1 - 2 * p_n0 * p_n1
        # Earlier versions accumulated into a float32 buffer; keep that dtype.
        return dt.sum(dim=-1).to(th.float32)

    def get_rand_p0s(self):
        """Return uniform random probabilities of shape ``(num_envs, num_nodes)``."""
        return th.rand((self.num_envs, self.num_nodes), dtype=th.float32, device=self.device)


def check_env():
    """Print the objective of one random batch as computed by each implementation.

    The seed is fixed so the three printed tensors can be compared between
    runs: the loop reference first, then version 1, then version 2.
    """
    th.manual_seed(0)
    env = GraphMaxCutEnv(num_envs=6)
    probs = env.get_rand_p0s()

    evaluators = (env.get_objective, env.get_objectives_v1, env.get_objectives)
    for evaluate in evaluators:
        print(evaluate(probs))


if __name__ == "__main__":
    # Run the consistency check only when executed as a script, so importing
    # this module does not read the graph file or print anything.
    check_env()

检查结果如下:

  • 第一行,无矢量并行的环境,在for循环里求出的6个结果
  • 第二行,矢量并行的环境,version1,求出的结果无误
  • 第三行,矢量并行的环境,version2,跳过部分节点数量为0的计算,结果相差在 1e-3 以内,可以接受
tensor([2345.0999, 2354.1797, 2337.8169, 2338.0452, 2332.6572, 2356.8047])
tensor([2345.0999, 2354.1797, 2337.8169, 2338.0452, 2332.6572, 2356.8047])
tensor([2345.0999, 2354.1794, 2337.8171, 2338.0452, 2332.6570, 2356.8047])