google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Home Page: http://jax.readthedocs.io/

Implement `scipy.optimize.linear_sum_assignment`

carlosgmartin opened this issue

Implement scipy.optimize.linear_sum_assignment, which solves the assignment problem. Among other things, this is useful for estimating the Wasserstein distance between two distributions based on their empirical measures.
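
As a minimal sketch of that use case (toy data, using SciPy directly): with equally many samples drawn from each distribution, the empirical 1-Wasserstein distance is the mean cost of the optimal assignment on the pairwise distance matrix.

import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
x = rng.normal(size=(128, 2))            # samples from the first distribution
y = rng.normal(loc=1.0, size=(128, 2))   # samples from the second distribution

# Pairwise Euclidean distances between the two sample sets (the ground metric).
cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)

# The optimal assignment minimises the total matching cost; its mean is the
# empirical 1-Wasserstein distance between the two sets of samples.
row_ind, col_ind = linear_sum_assignment(cost)
print(cost[row_ind, col_ind].mean())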

@hawkinsp I can take this up. Could you please provide a bit more context on where to add the function, and any references if possible?

Anyone working on this? It would be nice to have this in JAX.

For reference, the SciPy documentation mentions the paper "On implementing 2D rectangular assignment algorithms", which reviews many related algorithms; there is also a more recent paper, "A Fast Scalable Solver for the Dense Linear (Sum) Assignment Problem", that attempts to parallelise the algorithm.

@avinashsai @riversdark Here is scipy's C++ implementation. I ported it to JAX, though it's not in a fully jittable form yet:

from itertools import count

from jax import numpy as jnp, random, jit
from jax.lax import cond, while_loop
from scipy.optimize import linear_sum_assignment

def augmenting_path(cost, u, v, path, row4col, i):
    # Dijkstra-like search for the shortest augmenting path starting from row i,
    # mirroring SciPy's C++ implementation of the rectangular assignment solver.
    minVal = 0
    num_remaining = cost.shape[1]
    remaining = jnp.arange(cost.shape[1])[::-1]

    # SR/SC mark the rows/columns already scanned; shortestPathCosts[j] is the
    # cheapest known cost of a path reaching column j.
    SR = jnp.full(cost.shape[0], False)
    SC = jnp.full(cost.shape[1], False)
    shortestPathCosts = jnp.full(cost.shape[1], jnp.inf)

    sink = -1
    while sink == -1:
        index = -1
        lowest = jnp.inf
        SR = SR.at[i].set(True)

        for it in range(num_remaining):
            j = remaining[it]

            r = minVal + cost[i, j] - u[i] - v[j]

            path = cond(
                r < shortestPathCosts[j],
                lambda: path.at[j].set(i),
                lambda: path
            )
            shortestPathCosts = shortestPathCosts.at[j].min(r)

            index = cond(
                (shortestPathCosts[j] < lowest) | 
                ((shortestPathCosts[j] == lowest) & (row4col[j] == -1)),
                lambda: it,
                lambda: index
            )
            lowest = jnp.minimum(lowest, shortestPathCosts[j])

        minVal = lowest
        if minVal == jnp.inf: # infeasible cost matrix
            sink = -1
            break

        j = remaining[index]

        pred = row4col[j] == -1
        sink = cond(pred, lambda: j, lambda: sink)
        i = cond(~pred, lambda: row4col[j], lambda: i)

        SC = SC.at[j].set(True)
        num_remaining -= 1
        remaining = remaining.at[index].set(remaining[num_remaining])

    return sink, minVal, remaining, SR, SC, shortestPathCosts, path

def solve(cost):
    # Work on a matrix with no more rows than columns, as SciPy does internally.
    transpose = cost.shape[1] < cost.shape[0]

    if transpose:
        cost = cost.T

    # u, v are the dual variables; col4row/row4col hold the current (partial)
    # assignment, with -1 meaning "unassigned".
    u = jnp.full(cost.shape[0], 0.)
    v = jnp.full(cost.shape[1], 0.)
    path = jnp.full(cost.shape[1], -1)
    col4row = jnp.full(cost.shape[0], -1)
    row4col = jnp.full(cost.shape[1], -1)

    for curRow in range(cost.shape[0]):

        j, minVal, remaining, SR, SC, shortestPathCosts, path = augmenting_path(cost, u, v, path, row4col, curRow)

        # Update the dual variables for the rows and columns visited by the search.
        u = u.at[curRow].add(minVal)

        mask = SR & (jnp.arange(cost.shape[0]) != curRow)
        u = u.at[mask].add(minVal - shortestPathCosts[col4row][mask])

        v = v.at[SC].add(shortestPathCosts[SC] - minVal)

        # Augment: walk back along the predecessor array `path`, flipping the
        # assignment of each (row, column) pair until curRow is reached.
        while True:
            i = path[j]
            row4col = row4col.at[j].set(i)

            col4row, j = col4row.at[i].set(j), col4row[i]

            if i == curRow:
                break

    if transpose:
        v = col4row.argsort()
        return col4row[v], v
    else:
        return jnp.arange(cost.shape[0]), col4row

def main():
    # Compare against scipy.optimize.linear_sum_assignment on random cost matrices.
    key = random.PRNGKey(0)
    for t in count():
        key, subkey = random.split(key)
        shape = random.randint(subkey, [2], 0, 6)

        key, subkey = random.split(key)
        cost = random.uniform(subkey, shape)

        if t < 0: # skip to failing case
            continue

        row_ind_1, col_ind_1 = linear_sum_assignment(cost)
        row_ind_2, col_ind_2 = solve(cost)

        print('{:5} {}'.format(t,
            (row_ind_1 == row_ind_2).all() and 
            (col_ind_1 == col_ind_2).all()
        ))

if __name__ == '__main__':
    main()

Any updates on this? This seems particularly important to have for set-to-set machine learning methods (e.g. DETR).
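
In the meantime, a possible workaround (a rough sketch, assuming a recent JAX that provides jax.pure_callback; the helper names here are illustrative) is to call SciPy's solver from jitted code through a host callback:

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import linear_sum_assignment

def lsa(cost):
    # Host callback into SciPy; assumes a square cost matrix whose size is
    # static at trace time. Not differentiable and not vmapped as written.
    n = cost.shape[0]
    result_shape = (jax.ShapeDtypeStruct((n,), jnp.int32),
                    jax.ShapeDtypeStruct((n,), jnp.int32))
    def host_fn(c):
        r, c_ = linear_sum_assignment(np.asarray(c))
        return r.astype(np.int32), c_.astype(np.int32)
    return jax.pure_callback(host_fn, result_shape, cost)

@jax.jit
def matching_cost(cost):
    row_ind, col_ind = lsa(cost)
    return cost[row_ind, col_ind].sum()

print(matching_cost(jnp.array([[4., 1., 3.],
                               [2., 0., 5.],
                               [3., 2., 2.]])))

Since DETR-style losses treat the matching indices as non-differentiable anyway, the fact that gradients do not flow through the callback is usually not a problem for that use case.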

@avinashsai Are you still interested in implementing this?