NTT123 / soft-dtw-jax

Soft-DTW loss in JAX

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

sdtw-jax

Soft-DTW loss (with warp penalty) in JAX.

Usage:

# A batch of 4 random sequences; each batch item is a 2-D (800, 80) array,
# which is the per-item layout `sdtw` unpacks as (length, features).
x = jax.random.normal(jax.random.PRNGKey(42), (4, 800, 80))
# The same sequences shifted (circularly) by 4 positions along axis 1.
y = jnp.roll(x, 4, axis=1)
# One loss value is returned per batch item.
batched_sdtw(x, y, warp_penalty=1.0, temperature=0.01)
# DeviceArray([16.947954, 16.809141, 16.411541, 17.066374], dtype=float32)

Source:

# Reference: https://arxiv.org/abs/2006.03575

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial


def soft_minimum(values, temperature):
    """Return a smooth, element-wise minimum over a sequence of arrays.

    The arrays in `values` are stacked along a new trailing axis and reduced
    with a temperature-scaled negative log-sum-exp over that axis. The smaller
    `temperature` is, the closer the result is to the hard minimum.

    Args:
        values: Sequence of arrays that can be stacked together.
        temperature: Smoothing factor used to scale the log-sum-exp.

    Returns:
        An array with the shape of a single element of `values`.
    """
    stacked = jnp.stack(values, axis=-1)
    scaled = -stacked / temperature
    log_sum = jax.nn.logsumexp(scaled, axis=-1)
    return -temperature * log_sum


def skew_matrix(x):
    """Skew a matrix so that the diagonals become the rows.

    Column `j` of `x` is shifted down by `j` steps, so row `i` of the result
    holds the entries `x[i - j, j]`, i.e. the `i`-th anti-diagonal of `x`.
    Row indices that fall outside `[0, height - 1]` are clamped to the nearest
    valid row, so out-of-range positions repeat the first or last row entry of
    their column.

    Args:
        x: 2-D array of shape `(height, width)`.

    Returns:
        Array of shape `(height + width - 1, width)`.
    """
    height, width = x.shape
    # The gather indices depend only on the (static) shape, so build them
    # with one broadcasted NumPy expression instead of a Python double loop.
    diag = np.arange(height + width - 1, dtype=np.int32)[:, None]
    col = np.arange(width, dtype=np.int32)[None, :]
    ids = np.clip(diag - col, 0, height - 1).astype(np.int32, copy=False)
    return jnp.take_along_axis(x, ids, axis=0)


def kernel_dist(kernel, xs, ys):
    """Evaluate `kernel` on every pair drawn from `xs` and `ys`.

    Args:
        kernel: Callable taking one item of `xs` and one item of `ys`.
        xs: Array whose leading axis indexes the first set of items.
        ys: Array whose leading axis indexes the second set of items.

    Returns:
        An array `a` such that `a[i, j] = kernel(xs[i], ys[j])`.
    """
    # Inner map: hold the `xs` item fixed and sweep over the items of `ys`.
    over_ys = jax.vmap(kernel, in_axes=(None, 0))
    # Outer map: sweep over the items of `xs`, passing all of `ys` each time.
    over_xs = jax.vmap(over_ys, in_axes=(0, None))
    return over_xs(xs, ys)


@partial(jax.jit, static_argnums=(2, 3, 4))
def sdtw(a, b, warp_penalty=1.0, temperature=0.01, INFINITY=1e8):
    """Compute the Soft-DTW loss (with warp penalty) between two sequences.

    The pairwise cost is the mean absolute difference between frames. The
    accumulated path cost is evaluated one anti-diagonal at a time with
    `jax.lax.scan`, using `soft_minimum` over the three predecessor cells.

    Args:
        a: 2-D array of shape `(N, D)`.
        b: 2-D array of shape `(M, D)`; must share the last dimension of `a`.
        warp_penalty: Value added to the two candidates taken from the
            previous anti-diagonal (the non-diagonal moves).
        temperature: Smoothing temperature passed to `soft_minimum`.
        INFINITY: Large finite value that stands in for unreachable cells.

    Returns:
        The last entry of the final path-cost vector, i.e. the accumulated
        soft cost at the cell pairing the last frame of `a` with the last
        frame of `b`.

    The three hyper-parameters are declared static for `jax.jit`.
    """
    _, feat_a = a.shape
    _, feat_b = b.shape
    assert feat_a == feat_b

    def frame_distance(u, v):
        return jnp.mean(jnp.abs(u - v), axis=-1)

    pairwise = kernel_dist(frame_distance, a, b)
    width = pairwise.shape[-1]
    # After skewing, each row is one anti-diagonal of the cost matrix.
    diagonals = skew_matrix(pairwise)

    # Both carries have a leading padding slot. The older one starts with 0
    # in that slot so that the very first cell can be reached at no cost.
    newer_init = INFINITY * jnp.ones((width + 1,))
    older_init = jnp.pad(
        INFINITY * jnp.ones((width,)), (1, 0), constant_values=0.0
    )

    def step(carry, diag_cost):
        older, newer = carry
        warped = newer + warp_penalty
        # Candidates: two diagonals back (no penalty), plus the two
        # neighbouring entries of the previous diagonal (penalised).
        candidates = [older[:-1], warped[1:], warped[:-1]]
        accumulated = diag_cost + soft_minimum(candidates, temperature)
        current = jnp.pad(accumulated, (1, 0), constant_values=INFINITY)
        return (newer, current), None

    (_, final_cost), _ = jax.lax.scan(
        step, (older_init, newer_init), diagonals
    )
    return final_cost[-1]


@partial(jax.jit, static_argnums=(2, 3, 4))
def batched_sdtw(a, b, warp_penalty=1.0, temperature=0.01, INFINITY=1e8):
    """Apply `sdtw` to every pair of sequences in two batches.

    Args:
        a: Batch of sequences, batched along the leading axis.
        b: Batch of sequences, batched along the leading axis.
        warp_penalty: Forwarded unchanged to `sdtw`.
        temperature: Forwarded unchanged to `sdtw`.
        INFINITY: Forwarded unchanged to `sdtw`.

    Returns:
        An array with one `sdtw` value per batch item.
    """

    def single_pair(seq_a, seq_b):
        return sdtw(
            seq_a,
            seq_b,
            warp_penalty=warp_penalty,
            temperature=temperature,
            INFINITY=INFINITY,
        )

    # `jax.vmap` maps over the leading axis of both inputs by default.
    return jax.vmap(single_pair)(a, b)

About

Soft-DTW loss in JAX

License: MIT License