qiangge1987 / blackbox-backprop

Torch modules that wrap blackbox combinatorial solvers according to the method presented in "Differentiation of Blackbox Combinatorial Solvers"

Differentiation of Blackbox Combinatorial Solvers

By Marin Vlastelica*, Anselm Paulus*, Vít Musil, Georg Martius and Michal Rolínek.

Autonomous Learning Group, Max Planck Institute for Intelligent Systems.

Table of Contents

  1. Introduction
  2. Installation
  3. Content
  4. Usage
  5. Notes

Introduction

This repository contains PyTorch modules that wrap blackbox combinatorial solvers via the method proposed in Differentiation of Blackbox Combinatorial Solvers.

Disclaimer: This code is a PROTOTYPE. It should work fine, but use it at your own risk.

For the exact usage of the combinatorial modules, see the wider codebase.
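The core idea of the method fits in a few lines of plain Python. This is a minimal sketch, not the library's code: it assumes the solver returns a minimiser of the linear objective <w, y>, and the names `rank_solver` and `blackbox_backward` are hypothetical. In the backward pass the weights are perturbed in the direction of the incoming gradient dL/dy, the solver is called once more, and the negated difference of the two solutions, scaled by 1/lambda, is returned as the gradient. The toy solver here is a ranking computed by sorting, in the spirit of the Ranking module listed under Content.

```python
from typing import Callable, List

Vector = List[float]


def rank_solver(scores: Vector) -> Vector:
    """Toy solver (hypothetical): rank 1 for the highest score, rank n for the lowest.
    This ranking minimises sum(scores[i] * ranks[i]) over all permutations."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0.0] * len(scores)
    for position, index in enumerate(order):
        ranks[index] = float(position + 1)
    return ranks


def blackbox_backward(
    solver: Callable[[Vector], Vector],
    weights: Vector,
    grad_output: Vector,
    lambda_val: float,
) -> Vector:
    """Sketch of the gradient of the loss w.r.t. the solver's input weights."""
    y = solver(weights)  # forward solution (would be cached from the forward pass)
    # Perturb the weights in the direction of the incoming gradient dL/dy.
    perturbed = [w + lambda_val * g for w, g in zip(weights, grad_output)]
    y_lambda = solver(perturbed)  # one extra solver call
    # Negated difference of the two solutions, scaled by 1 / lambda.
    return [-(a - b) / lambda_val for a, b in zip(y, y_lambda)]


# Example: the loss is the rank of item 1, so dL/dranks = [0, 1, 0].
scores = [3.0, 1.0, 2.0]
grad = blackbox_backward(rank_solver, scores, [0.0, 1.0, 0.0], lambda_val=5.0)
print(grad)  # [0.2, -0.4, 0.2]
```

A gradient-descent step on `scores` raises the score of item 1 and lowers the others, which moves item 1 up the ranking. With a very small `lambda_val` (0.1 in this example) the perturbed ranking is unchanged and the gradient is zero, so the hyperparameter trades off how informative the gradient is against how closely it follows the original piecewise-constant solver.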

Installation

Install with pip:

python3 -m pip install git+https://github.com/martius-lab/blackbox-backprop

Running the TSP module additionally requires a manual installation of GurobiPy and a Gurobi license.

Content

The following solver modules are currently available (more will be added over time):

Combinatorial Problem | Solver | Paper
Travelling Salesman | Cutting plane algorithm implemented in Gurobi | Differentiation of Blackbox Combinatorial Solvers
Shortest Path (on a grid) | Dijkstra algorithm (vertex version) | Differentiation of Blackbox Combinatorial Solvers
Min-cost perfect matching on general graphs | Blossom V (Kolmogorov, 2009) | Differentiation of Blackbox Combinatorial Solvers
Ranking (+ induced Recall & mAP loss functions) | torch.argsort | Blackbox Optimization of Rank-Based Metrics
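The Shortest Path row uses Dijkstra's algorithm in its vertex version: the weights sit on grid cells rather than on edges, and a path costs the sum of the weights of the cells it visits. The plain-Python sketch below illustrates that kind of solver; it is not the library's implementation. The 8-neighbourhood, the top-left start, the bottom-right goal and the function name `grid_shortest_path` are assumptions made for illustration, and the weights must be non-negative for Dijkstra to be correct. It returns a 0/1 indicator grid of the cells on the path, the kind of output that can be compared against a true path with a Hamming loss.

```python
import heapq
from typing import Dict, List, Tuple

Grid = List[List[float]]
Cell = Tuple[int, int]

# Assumed 8-neighbourhood: horizontal, vertical and diagonal moves.
NEIGHBOURS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def grid_shortest_path(weights: Grid) -> List[List[int]]:
    """Vertex-weighted Dijkstra from the top-left to the bottom-right cell.
    Returns a 0/1 grid marking the cells on a cheapest path."""
    rows, cols = len(weights), len(weights[0])
    start, goal = (0, 0), (rows - 1, cols - 1)
    dist: Dict[Cell, float] = {start: weights[0][0]}  # start cell's weight counts
    parent: Dict[Cell, Cell] = {}
    heap = [(weights[0][0], start)]
    while heap:
        d, (r, c) = heapq.heappop(heap)
        if (r, c) == goal:
            break  # first pop of the goal carries its final distance
        if d > dist[(r, c)]:
            continue  # stale heap entry
        for dr, dc in NEIGHBOURS:
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols:
                nd = d + weights[nr][nc]  # entering a cell costs its weight
                if nd < dist.get((nr, nc), float("inf")):
                    dist[(nr, nc)] = nd
                    parent[(nr, nc)] = (r, c)
                    heapq.heappush(heap, (nd, (nr, nc)))
    # Walk back from the goal and mark the visited cells.
    path = [[0] * cols for _ in range(rows)]
    node = goal
    path[node[0]][node[1]] = 1
    while node != start:
        node = parent[node]
        path[node[0]][node[1]] = 1
    return path


weights = [
    [1.0, 9.0, 1.0],
    [1.0, 9.0, 1.0],
    [1.0, 1.0, 1.0],
]
print(grid_shortest_path(weights))  # [[1, 0, 0], [1, 0, 0], [0, 1, 1]]
```

A binary heap keeps each pop at O(log n); stale entries are skipped instead of being removed, which is the usual pattern with Python's `heapq`.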

Usage

Use the modules exactly as you would any other PyTorch module (minor details differ from solver to solver):

import blackbox_backprop as bb
...
suggested_weights = ResNet18(raw_inputs)
suggested_shortest_paths = bb.ShortestPath(suggested_weights, lambda_val=5.0) # Set the lambda hyperparameter
loss = HammingLoss(suggested_shortest_paths, true_shortest_paths) # Use e.g. Hamming distance as the loss function
loss.backward() # The backward pass is handled automatically
...
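The usage snippet leaves HammingLoss undefined. One way to write a Hamming loss, sketched here in plain Python with the hypothetical names `hamming_loss` and `hamming_loss_grad` (neither is part of blackbox_backprop), is the mean of y*(1-t) + (1-y)*t over all entries, which equals the fraction of mismatched entries for 0/1 inputs. Its gradient (1 - 2t)/N is the dL/dy that the solver module receives in the backward pass.

```python
from typing import List


def hamming_loss(suggested: List[float], target: List[float]) -> float:
    """Mean of y*(1-t) + (1-y)*t; for 0/1 inputs this is the fraction of mismatches."""
    errors = [y * (1.0 - t) + (1.0 - y) * t for y, t in zip(suggested, target)]
    return sum(errors) / len(errors)


def hamming_loss_grad(target: List[float]) -> List[float]:
    """dL/dy_i = (1 - 2 * t_i) / N, independent of the suggestion itself."""
    n = len(target)
    return [(1.0 - 2.0 * t) / n for t in target]


# Flattened path indicators: suggested path vs. true path.
suggested = [1.0, 0.0, 1.0, 0.0]
target = [1.0, 1.0, 0.0, 0.0]
print(hamming_loss(suggested, target))  # 0.5
print(hamming_loss_grad(target))  # [-0.25, -0.25, 0.25, 0.25]
```

In the blackbox backward rule, the perturbed weights w + lambda * grad make cells on the true path (t = 1) cheaper and all other cells more expensive before the solver is called again.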

Notes

Contribute: If you spot a bug or some incompatibility, raise an issue or contribute via a pull request! Thank you!

License: MIT License

