RobertTLange / evosax

Evolution Strategies in JAX 🦎

A GNN-based Meta-Learning Method for Sparse Portfolio Optimization

kayuksel opened this issue

Hello Robert,

Let me start by saying that I am a fan of your work here. I have recently open-sourced my GNN-based meta-learning method for optimization. After initial benchmarking on the Schwefel function, I applied it to a real-world sparse index-tracking problem, where it seems to significantly outperform Fast CMA-ES, both in the robustness of its solutions on the blind test set and in time (total duration and number of iterations) and space complexity. I include the link to my repository here, in case you would consider adding the method or the benchmark problem to your repository. Note: the GNN, which learns to generate a population of solutions at each iteration, is trained using gradients obtained from the loss function, as opposed to black-box gradient estimates.

Sincerely, K
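To make the distinction above concrete (gradients taken from the loss function versus black-box estimates), here is a minimal NumPy sketch on a hypothetical toy quadratic objective. It is not the GNN method and not evosax's API. `pathwise_grad` differentiates through `f` at sampled points, which is what backpropagation does in the PyTorch code later in this thread, while `es_grad` uses only objective values through an antithetic evolution-strategies estimator. All names here (`target`, `f`, `grad_f`, `pathwise_grad`, `es_grad`) are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
target = np.array([1.0, -2.0, 0.5])  # hypothetical optimum of the toy objective

def f(x):
    """Toy objective: squared distance of each row of x to `target`."""
    return np.sum((x - target) ** 2, axis=-1)

def grad_f(x):
    """Analytic gradient of f (only available when f is white-box/differentiable)."""
    return 2.0 * (x - target)

def pathwise_grad(mu, sigma, n):
    """Gradient-based: average grad_f over samples x = mu + sigma * eps."""
    eps = rng.standard_normal((n, mu.size))
    return grad_f(mu + sigma * eps).mean(axis=0)

def es_grad(mu, sigma, n):
    """Black-box: antithetic ES estimate built from f values only (no grad_f)."""
    eps = rng.standard_normal((n // 2, mu.size))
    diff = f(mu + sigma * eps) - f(mu - sigma * eps)  # shape (n // 2,)
    return (diff[:, None] * eps).mean(axis=0) / (2.0 * sigma)

mu = np.zeros(3)
print(grad_f(mu))                     # exact gradient at mu = 0: -2, 4, -1
print(pathwise_grad(mu, 0.1, 20_000))
print(es_grad(mu, 0.1, 20_000))
```

Both estimators target the gradient of the smoothed objective E[f(mu + sigma * eps)], which for this quadratic coincides with `grad_f(mu)`. The ES estimate never calls `grad_f`, at the cost of a noticeably higher variance for the same number of samples.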

Hi @kayuksel,

Very cool work. I assume there is no preprint? I am myself very much interested in meta-learning ES/black-box optimizers (see link). But how can I/evosax help you? Do you want to implement your GNN-based BBO in JAX and open a PR?

Best wishes,
Rob

Hello again,

Thank you for your response. By the way, I am also formerly of TU Berlin (through BCCN Berlin, DAI-Labor, and T-Labs). I actually developed this method two years ago but couldn't find the time to write a publication, as I work full-time in industry.

I first started with a game-theoretic mechanism for global optimization based on positive surprise (using a surrogate model as the critic for the generator). That turned out to be beneficial, but not critical, in many of the problems I tried it on.

Hence, I would probably need to rewrite my publication from scratch, since the draft was heavily focused on that idea. It was also difficult for me to implement as many benchmark methods as you have and to run the experiments. If you like, we could collaborate on such a publication.

Regarding a JAX implementation: the PyTorch code is actually under 100 lines. However, as I have never developed with JAX, I am unsure whether I could port it easily myself, or when I would be able to, as I currently have a heavy workload.

Sincerely, K

Here is a highly simplified version of the code, with weight initialization and adaptive gradient clipping removed.
What was actually most beneficial was a method called GradInit, which I used to minimize sensitivity to the random seed.

import torch
import torch.nn as nn
from types import SimpleNamespace

# The original reads these from a global `args` (e.g. argparse).
# The values below are hypothetical placeholders, not the author's settings.
args = SimpleNamespace(noise=32, batch=64, cnndim=32, funcd=100, iter=1000)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def schwefel(x):
    # Rescale candidates from [-1, 1] to the usual Schwefel domain [-500, 500].
    x = x * 500
    return 418.9829 * x.shape[1] - (x * x.abs().sqrt().sin()).sum(dim=1)

class LSTMModule(nn.Module):
    def __init__(self, input_size=1, hidden_size=1, num_layers=2):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Hidden and cell state, carried over from one call (iteration) to the next.
        self.register_buffer('h', torch.zeros(num_layers, 1, hidden_size))
        self.register_buffer('c', torch.zeros(num_layers, 1, hidden_size))

    def forward(self, x):
        self.rnn.flatten_parameters()
        out, (h_end, c_end) = self.rnn(x, (self.h, self.c))
        # Store the final state detached, so no graph is kept across iterations.
        self.h = h_end.detach()
        self.c = c_end.detach()
        return out[:, -1, :].flatten()

class Extractor(nn.Module):
    """Summarizes the whole population of noise vectors into one latent vector."""
    def __init__(self, latent_dim, ks=5):
        super().__init__()
        self.conv = nn.Conv1d(args.noise, latent_dim,
            bias=False, kernel_size=ks, padding=(ks // 2) + 1)
        nn.init.normal_(self.conv.weight, 0.0, 0.01)
        self.activation = nn.Sequential(nn.BatchNorm1d(
            latent_dim, track_running_stats=False), nn.Mish())
        # Pools the population axis down to length 1 (requires args.batch > 4),
        # which the LSTM below consumes as its single input feature.
        self.gap = nn.AvgPool1d(kernel_size=args.batch, padding=1)
        self.rnn = LSTMModule(hidden_size=latent_dim)

    def forward(self, x):
        # x: (batch, noise) -> (1, noise, batch): the population becomes the conv axis.
        y = x.unsqueeze(0).permute(0, 2, 1)
        y = self.rnn(self.gap(self.activation(self.conv(y))))
        # Append the shared population summary to every individual's noise vector.
        return torch.cat([x, y.repeat(args.batch, 1)], dim=1)

class Generator(nn.Module):
    def __init__(self, noise_dim=0):
        super().__init__()
        def block(in_feat, out_feat):
            return [nn.Linear(in_feat, out_feat), nn.Tanh()]
        self.model = nn.Sequential(
            *block(noise_dim + args.cnndim, 480), *block(480, 1103), nn.Linear(1103, args.funcd))
        self.extract = Extractor(args.cnndim)
        # Learned per-dimension noise scale, initialized to zero.
        self.std_weight = nn.Parameter(torch.zeros(args.funcd))

    def forward(self, x):
        mu = self.model(self.extract(x))
        return mu + (self.std_weight * torch.randn_like(mu))

actor = Generator(args.noise).to(device)
opt = torch.optim.AdamW([p for p in actor.parameters() if p.requires_grad], lr=1e-3)
best_reward = float('inf')
for epoch in range(args.iter):
    opt.zero_grad()
    # One noise vector per individual; the generator maps the batch to a population.
    z = torch.randn(args.batch, args.noise, device=device)
    # tanh keeps candidates in [-1, 1]; lower Schwefel values are better.
    rewards = schwefel(actor(z).tanh())
    # Minimize the population's mean objective by backpropagating through schwefel.
    actor_loss = rewards.mean()
    actor_loss.backward()
    nn.utils.clip_grad_norm_(actor.parameters(), 1.0)
    opt.step()
    # Report whenever the best individual matches or beats the best so far.
    current = rewards.min().item()
    if current <= best_reward:
        best_reward = current
        print('epoch: %i loss: %f' % (epoch, best_reward))
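Two of the snippet's hard-coded choices can be checked by hand. The pure-Python sketch below assumes the standard Schwefel optimum x_i ≈ 420.9687 and the PyTorch output-length formula for `nn.Conv1d` / `nn.AvgPool1d` (dilation 1, `ceil_mode=False`); the helper names `out_len` and `pooled_len` are hypothetical. It shows that the rescaled optimum 420.9687 / 500 ≈ 0.842 lies inside the tanh range, and that the Extractor's pooling leaves a sequence of length 1 (which the LSTM's `input_size=1` needs) only when the population size is at least 5.

```python
import math

def schwefel(x):
    """Pure-Python version of the snippet's objective for one candidate in [-1, 1]^d."""
    x = [500.0 * xi for xi in x]
    return 418.9829 * len(x) - sum(xi * math.sin(math.sqrt(abs(xi))) for xi in x)

def out_len(length, kernel, padding, stride=1):
    """Output length of nn.Conv1d / nn.AvgPool1d with dilation 1 and ceil_mode=False."""
    return (length + 2 * padding - kernel) // stride + 1

def pooled_len(batch, ks=5):
    """Sequence length reaching the LSTM inside Extractor for population size `batch`."""
    conv = out_len(batch, ks, ks // 2 + 1)        # conv output length: batch + 2
    return out_len(conv, batch, 1, stride=batch)  # AvgPool1d stride defaults to kernel_size

d = 10
print(schwefel([420.9687 / 500.0] * d))  # close to 0 (global minimum)
print(schwefel([0.0] * d))               # 418.9829 * d
for batch in (2, 4, 5, 64):
    print(batch, pooled_len(batch))      # pooled length is 1 only for batch >= 5
```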

Correct me if I am wrong, but you do need to be able to compute gradients through the objective function, right? So it is not generally applicable to black-box optimization, and is closer to Luke Metz's line of work on learned gradient-based optimization.

Yes, that is true. For black-box optimization, it is possible to train a surrogate model of the objective and take the gradients from it instead. But then convergence was much slower, at least in the few attempts I made to overcome that issue. I am unsure which of Luke Metz's works you are referring to, but the main difference may be that my method is a population-based generative one, like ES/QD algorithms.
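The surrogate route mentioned above can be sketched as follows, under loudly stated assumptions: the thread does not say which surrogate was used, so a local linear least-squares model stands in for it here, and `black_box` is a hypothetical toy objective whose values, but not gradients, are observable. The slope of the fitted surrogate is then used as the gradient for plain gradient descent.

```python
import numpy as np

rng = np.random.default_rng(0)

def black_box(x):
    """Hypothetical black-box objective: only its values are observed, never its gradient."""
    return np.sum((x - 3.0) ** 2, axis=-1)

def surrogate_grad(x0, sigma=0.1, n=32):
    """Fit a local linear surrogate f(x) ~ a + g . (x - x0) by least squares
    on n perturbed evaluations, and return its slope g as a gradient estimate."""
    dx = sigma * rng.standard_normal((n, x0.size))
    y = black_box(x0 + dx)
    A = np.hstack([np.ones((n, 1)), dx])  # design matrix with an intercept column
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef[1:]                       # drop the intercept a

x = np.zeros(5)
for step in range(200):
    x -= 0.1 * surrogate_grad(x)          # gradient descent on the surrogate's slope
print(x)                                  # should end up near the optimum at 3.0
```

A learned model trained on all past evaluations (for example a small neural network) is the more typical surrogate; this sketch does not show whether such a choice would avoid the slower convergence described above.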

Okay, I will close this issue for now, since it is a gradient-based method and there is no reference paper (yet). Also, I don't have a trained checkpoint or details on the meta-training. Please feel free to open a PR once you are ready and have a JAX implementation. Luke Metz's work can be found here.