Implement `scipy.optimize.linear_sum_assignment`
carlosgmartin opened this issue · comments
Implement `scipy.optimize.linear_sum_assignment`, which solves the assignment problem. Among other things, this is useful for estimating the Wasserstein distance between two distributions based on their empirical measures.
@hawkinsp I can take this up. Could you please provide a bit more context on where to add the function, and any references if possible?
Anyone working on this? It would be nice to have this in JAX.
For reference, there is the paper "On implementing 2D rectangular assignment algorithms", cited in the SciPy documentation, which reviews many related algorithms, and a more recent paper, "A Fast Scalable Solver for the Dense Linear (Sum) Assignment Problem", which attempts to parallelise the algorithm.
@avinashsai @riversdark Here is SciPy's C++ implementation. I ported it to JAX, though it is not yet in a fully JIT-able form:
from itertools import count
from jax import numpy as jnp, random, jit
from jax.lax import cond, while_loop
from scipy.optimize import linear_sum_assignment
def augmenting_path(cost, u, v, path, row4col, i):
    """Find a shortest augmenting path that starts at row ``i``.

    One search step of the SciPy C++ solver this script ports. It runs a
    Dijkstra-like scan over the reduced costs ``cost[i, j] - u[i] - v[j]``.
    The scan grows a set of visited rows/columns until the cheapest unvisited
    column has no assigned row; that column is the sink of the path.

    The loop conditions (``while sink == -1``, ``if minVal == jnp.inf``) need
    concrete array values, so this function runs eagerly and cannot be traced
    by ``jax.jit`` as written.

    Args:
        cost: 2-D cost matrix indexed ``cost[row, col]``. ``solve`` only calls
            this after transposing so there are no more rows than columns.
        u: row dual potentials (read only).
        v: column dual potentials (read only).
        path: predecessor row for each column. An entry is overwritten
            whenever a shorter path to that column is found; the updated
            array is returned.
        row4col: row currently assigned to each column, -1 if unassigned
            (read only).
        i: row the search starts from; rebound as the search moves to the
            rows that own visited columns.

    Returns:
        Tuple ``(sink, minVal, remaining, SR, SC, shortestPathCosts, path)``:

        - ``sink``: unassigned column that ends the path, or -1 when no
          remaining column is reachable (``minVal`` is inf, i.e. the cost
          matrix is infeasible).
        - ``minVal``: length of the shortest path found.
        - ``remaining``: column indices whose leading entries are the
          still-unvisited columns (their count is not returned).
        - ``SR``, ``SC``: boolean masks of the rows / columns visited.
        - ``shortestPathCosts``: tentative shortest-path length per column.
        - ``path``: the updated predecessor array.
    """
    minVal = 0  # length of the shortest path finalized so far
    num_remaining = cost.shape[1]  # live prefix length of `remaining`
    # NOTE(review): the reversed order presumably mirrors SciPy's C++ code,
    # where it makes ties resolve to the identity assignment for constant
    # cost matrices -- verify against rectangular_lsap.cpp.
    remaining = jnp.arange(cost.shape[1])[::-1]
    SR = jnp.full(cost.shape[0], False)  # rows visited by the search
    SC = jnp.full(cost.shape[1], False)  # columns finalized by the search
    shortestPathCosts = jnp.full(cost.shape[1], jnp.inf)
    sink = -1  # -1 while no unassigned column has been reached
    while sink == -1:
        index = -1  # position in `remaining` of the cheapest column
        lowest = jnp.inf
        SR = SR.at[i].set(True)
        # Relax every still-unvisited column through the current row i.
        for it in range(num_remaining):
            j = remaining[it]
            r = minVal + cost[i, j] - u[i] - v[j]
            # Record i as j's predecessor when it gives a shorter path.
            path = cond(
                r < shortestPathCosts[j],
                lambda: path.at[j].set(i),
                lambda: path
            )
            # Same improvement as above, applied to the distance itself.
            shortestPathCosts = shortestPathCosts.at[j].min(r)
            # Track the cheapest column; on ties prefer an unassigned one
            # (row4col == -1) so the search can stop there.
            index = cond(
                (shortestPathCosts[j] < lowest) |
                ((shortestPathCosts[j] == lowest) & (row4col[j] == -1)),
                lambda: it,
                lambda: index
            )
            # Equivalent to setting `lowest` whenever `index` changed above.
            lowest = jnp.minimum(lowest, shortestPathCosts[j])
        minVal = lowest
        if minVal == jnp.inf:  # infeasible cost matrix
            sink = -1
            break
        j = remaining[index]
        # An unassigned column ends the path; otherwise continue the search
        # from the row that currently owns column j.
        pred = row4col[j] == -1
        sink = cond(pred, lambda: j, lambda: sink)
        i = cond(~pred, lambda: row4col[j], lambda: i)
        SC = SC.at[j].set(True)
        # Swap-remove: overwrite the visited slot with the last live entry.
        num_remaining -= 1
        remaining = remaining.at[index].set(remaining[num_remaining])
    return sink, minVal, remaining, SR, SC, shortestPathCosts, path
def solve(cost, maximize=False):
    """Solve the linear sum assignment problem for a rectangular cost matrix.

    Eager JAX port of SciPy's C++ shortest-augmenting-path solver (the method
    from Crouse, "On implementing 2D rectangular assignment algorithms").

    Rows join the matching one at a time. For each row, the solver:

    - finds a shortest augmenting path with ``augmenting_path``;
    - updates the dual potentials of everything the search visited;
    - flips the assignments along the path.

    Args:
        cost: 2-D array-like cost matrix of shape (n_rows, n_cols).
        maximize: if True, return the assignment with the largest total cost
            instead of the smallest (same meaning as the ``maximize`` argument
            of ``scipy.optimize.linear_sum_assignment``). Defaults to False.

    Returns:
        ``(row_ind, col_ind)``: integer arrays of length
        ``min(n_rows, n_cols)``. ``row_ind`` is sorted ascending, and row
        ``row_ind[k]`` is assigned to column ``col_ind[k]``.

    Raises:
        ValueError: if ``cost`` is not 2-D; if it contains NaN or -inf entries
            (checked after negation when ``maximize`` is True); or if no
            complete assignment with finite total cost exists.
    """
    cost = jnp.asarray(cost, dtype=float)
    if cost.ndim != 2:
        raise ValueError('expected a matrix (2-D array), got a {!r} array'
                         .format(cost.shape))
    # Every row gets matched, so the solver needs n_rows <= n_cols. For tall
    # matrices, solve the transposed problem and map the result back at the end.
    transpose = cost.shape[1] < cost.shape[0]
    if transpose:
        cost = cost.T
    if maximize:
        cost = -cost
    # NaN makes every comparison in the search false and -inf makes path
    # lengths unbounded; reject both before searching.
    if bool(jnp.isnan(cost).any() | jnp.isneginf(cost).any()):
        raise ValueError('matrix contains invalid numeric entries')

    n_rows, n_cols = cost.shape
    u = jnp.full(n_rows, 0.)        # row dual potentials
    v = jnp.full(n_cols, 0.)        # column dual potentials
    path = jnp.full(n_cols, -1)     # predecessor row of each column
    col4row = jnp.full(n_rows, -1)  # column assigned to each row (-1: none)
    row4col = jnp.full(n_cols, -1)  # row assigned to each column (-1: none)

    # Empty matrices fall through: the loop body never runs.
    for cur_row in range(n_rows):
        sink, min_val, _, SR, SC, shortest_path_costs, path = augmenting_path(
            cost, u, v, path, row4col, cur_row)
        # augmenting_path reports -1 when no finite-cost path exists. Without
        # this check, the walk below would start from path[-1] and corrupt
        # (or never finish rebuilding) the matching.
        if sink == -1:
            raise ValueError('cost matrix is infeasible')

        # Update the duals of every row / column the search visited.
        u = u.at[cur_row].add(min_val)
        # Visited rows other than cur_row were reached through an assigned
        # column, so col4row is valid under `mask`. The -1 entries of
        # unmatched rows only index harmlessly and are dropped by the mask.
        mask = SR & (jnp.arange(n_rows) != cur_row)
        u = u.at[mask].add(min_val - shortest_path_costs[col4row][mask])
        v = v.at[SC].add(shortest_path_costs[SC] - min_val)

        # Walk the path back from the sink to cur_row. Each column on it is
        # reassigned to its predecessor row, and col4row entries are swapped
        # along the way.
        j = sink
        while True:
            i = path[j]
            row4col = row4col.at[j].set(i)
            col4row, j = col4row.at[i].set(j), col4row[i]
            if i == cur_row:
                break

    if transpose:
        # After transposing, col4row maps original columns to original rows;
        # sort by row so the output is ordered like the untransposed case.
        order = col4row.argsort()
        return col4row[order], order
    return jnp.arange(n_rows), col4row
def main(num_trials=None, first_trial=0):
    """Cross-check ``solve`` against ``scipy.optimize.linear_sum_assignment``.

    Each trial draws a uniform random cost matrix whose two dimensions are
    drawn independently from 0..5. Both solvers run on it, and the trial
    number is printed followed by True if the row and column indices agree
    exactly, False otherwise.

    Args:
        num_trials: number of trials to run; None (the default) runs forever.
        first_trial: trials numbered below this are not solved, but they still
            advance the PRNG key, so later matrices are unchanged. Use this to
            jump straight to a known failing trial. Defaults to 0, which
            checks every trial.
    """
    key = random.PRNGKey(0)
    trials = count() if num_trials is None else range(num_trials)
    for t in trials:
        key, subkey = random.split(key)
        # Materialize the shape as Python ints so random.uniform receives a
        # plain static shape tuple.
        shape = tuple(int(n) for n in random.randint(subkey, [2], 0, 6))
        key, subkey = random.split(key)
        cost = random.uniform(subkey, shape)
        if t < first_trial:
            continue
        row_ind_1, col_ind_1 = linear_sum_assignment(cost)
        row_ind_2, col_ind_2 = solve(cost)
        # NOTE(review): `cost` is JAX's default float32, while SciPy presumably
        # works in float64, so near-ties may be broken differently. A False
        # here is not necessarily a suboptimal assignment -- compare total
        # costs before concluding that solve is wrong.
        matches = bool((row_ind_1 == row_ind_2).all()
                       and (col_ind_1 == col_ind_2).all())
        print('{:5} {}'.format(t, matches))


if __name__ == '__main__':
    main()
Any updates on this? This seems particularly important for set-to-set machine learning methods (e.g. DETR).
@avinashsai Are you still interested in implementing this?