MidoriYakumo / pytorch-struct

A library of vectorized implementations of core structured prediction algorithms (HMM, Dep Trees, CKY, ...)

Pytorch-Struct

A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications (or an implementation of "Inside-Outside and Forward-Backward Algorithms Are Just Backprop").
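The idea behind the quoted title can be checked on a tiny linear chain without the library. The sketch below is an illustration under stated assumptions, not the library's implementation: it uses only the Python standard library, the helper names (`log_partition`, `gradient_entry`, `enumerated_marginal`) and the `theta[n][j][i]` layout are hypothetical, and a central finite difference stands in for backprop. It compares the derivative of the log-partition with respect to one transition score against that transition's marginal probability computed by listing every path.

```python
import copy
import math
from itertools import product


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) for a list of finite floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_partition(theta):
    """Forward algorithm in log space.

    theta[n][j][i] is the (hypothetical) score of moving from state i at
    step n to state j at step n + 1.
    """
    num_states = len(theta[0])
    alpha = [0.0] * num_states  # log-score 0 for every start state
    for step in theta:
        alpha = [
            logsumexp([alpha[i] + step[j][i] for i in range(num_states)])
            for j in range(num_states)
        ]
    return logsumexp(alpha)


def gradient_entry(theta, n, j, i, eps=1e-5):
    """Central finite difference of log_partition w.r.t. theta[n][j][i]."""
    up, down = copy.deepcopy(theta), copy.deepcopy(theta)
    up[n][j][i] += eps
    down[n][j][i] -= eps
    return (log_partition(up) - log_partition(down)) / (2 * eps)


def enumerated_marginal(theta, n, j, i):
    """Probability of the transition i -> j at step n, by listing every path."""
    num_states, num_steps = len(theta[0]), len(theta)
    total = hit = 0.0
    for path in product(range(num_states), repeat=num_steps + 1):
        weight = math.exp(
            sum(theta[t][path[t + 1]][path[t]] for t in range(num_steps))
        )
        total += weight
        if path[n] == i and path[n + 1] == j:
            hit += weight
    return hit / total


# Arbitrary illustrative scores: 3 transitions over 2 states.
theta = [
    [[0.3, -1.2], [0.8, 0.1]],
    [[-0.5, 0.4], [1.1, -0.7]],
    [[0.2, 0.9], [-0.3, 0.6]],
]

grad = gradient_entry(theta, n=1, j=0, i=1)
marginal = enumerated_marginal(theta, n=1, j=0, i=1)
print(grad, marginal)  # the two values should agree up to finite-difference error
```

With an autograd framework, the finite difference would be replaced by a single backward pass through the forward recursion, which yields all marginals at once.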

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
!pip install -q matplotlib
import torch
from torch_struct import DepTree, LinearChain, MaxSemiring, SampledSemiring
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
vals = vals.log()
show(vals[0])

# Compute marginals
marginals = DepTree().marginals(vals)
show(marginals[0])

# Compute argmax
argmax = DepTree(MaxSemiring).marginals(vals)
show(argmax.detach()[0])

# Compute scoring and enumeration (forward / inside)
log_partition = DepTree().sum(vals)
max_score = DepTree(MaxSemiring).sum(vals)
# Scoring the argmax structure should give back the same max score
argmax_score = DepTree().score(argmax, vals)
# Compute samples
sample = DepTree(SampledSemiring).marginals(vals)
show(sample.detach()[0])

# Padding/Masking built into library.
marginals = DepTree().marginals(
    vals,
    lengths=torch.tensor([10, 7]))
show(marginals[0])
plt.show()
show(marginals[1])

# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) 
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()

marginals = LinearChain().marginals(chain)
show(marginals.detach()[0].sum(-1))

Library

Current algorithms implemented:

  • Linear Chain (CRF / HMM)

  • Semi-Markov (CRF / HSMM)

  • Dependency Parsing (Projective and Non-Projective)

  • CKY (CFG)

  • Integration with torchtext and pytorch-transformers

Design Strategy:

  1. Minimal implementations. Most are 10 lines.
  2. Batched for GPU.
  3. Code can be ported to other backends.

Semirings:

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
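The first two items in the list above can be sketched with a single recursion that is parameterized by a semiring. This is a minimal standard-library illustration, not the library's code: `ToyLogSemiring`, `ToyMaxSemiring`, `chain_sum`, and the `theta[n][j][i]` layout are hypothetical names chosen for this example. Swapping the semiring's "plus" from logsumexp to max turns the log-partition computation into the best-path score, and both results are checked against brute-force enumeration. Sampling is not covered here.

```python
import math
from itertools import product


class ToyLogSemiring:
    """Hypothetical stand-in: 'plus' is logsumexp, 'times' is addition."""

    @staticmethod
    def plus(xs):
        m = max(xs)
        return m + math.log(sum(math.exp(x - m) for x in xs))

    @staticmethod
    def times(a, b):
        return a + b


class ToyMaxSemiring(ToyLogSemiring):
    """Same 'times', but 'plus' keeps only the best alternative."""

    @staticmethod
    def plus(xs):
        return max(xs)


def chain_sum(theta, semiring):
    """Forward recursion over a linear chain under the given semiring.

    theta[n][j][i] is the (hypothetical) log-score of moving from state i
    at step n to state j at step n + 1.
    """
    num_states = len(theta[0])
    alpha = [0.0] * num_states  # 0.0 is the 'times' identity for both semirings
    for step in theta:
        alpha = [
            semiring.plus(
                [semiring.times(alpha[i], step[j][i]) for i in range(num_states)]
            )
            for j in range(num_states)
        ]
    return semiring.plus(alpha)


def path_scores(theta):
    """Score of every state sequence, by brute-force enumeration."""
    num_states, num_steps = len(theta[0]), len(theta)
    return [
        sum(theta[t][path[t + 1]][path[t]] for t in range(num_steps))
        for path in product(range(num_states), repeat=num_steps + 1)
    ]


# Arbitrary illustrative scores: 3 transitions over 2 states.
theta = [
    [[0.3, -1.2], [0.8, 0.1]],
    [[-0.5, 0.4], [1.1, -0.7]],
    [[0.2, 0.9], [-0.3, 0.6]],
]

log_z = chain_sum(theta, ToyLogSemiring)  # log of the sum over all paths
best = chain_sum(theta, ToyMaxSemiring)   # score of the single best path
scores = path_scores(theta)
print(log_z, best)
```

Keeping the recursion fixed and changing only the semiring is what lets one implementation serve several inference tasks.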

Examples

License: MIT License


Languages

Jupyter Notebook 91.4%, Python 8.5%, Shell 0.0%