A library of tested GPU implementations of core structured prediction algorithms for deep learning applications (or, an implementation of "Inside-Outside and Forward-Backward Algorithms Are Just Backprop").
```python
# Install (the leading "!" runs a shell command from a Jupyter/Colab cell).
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
!pip install -q matplotlib
```
```python
import torch
from torch_struct import DepTree, LinearChain, MaxSemiring, SampledSemiring
import matplotlib.pyplot as plt

def show(x):
    # Plot a 2-D tensor as an image, detaching it from the autograd graph first.
    plt.imshow(x.detach())
```
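```python
# Make some data: a batch of 2 log-potential matrices (10 x 10), with two
# higher-scoring diagonal blocks and near-zero weight (1e-5) everywhere else.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5)
vals = vals.log()
show(vals[0])
```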
```python
# Compute marginals
marginals = DepTree().marginals(vals)
show(marginals[0])
```
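The "just backprop" idea behind `marginals` can be checked on a toy problem. The sketch below is a pure-Python illustration, not torch_struct code: it swaps the dependency tree for a tiny linear chain (an assumption, to keep brute-force enumeration cheap) and uses central finite differences in place of autograd. The edge marginal p(z_n = i, z_{n+1} = j), computed by enumeration, should match the derivative of log Z with respect to the potential `edge[n][i][j]`. The names `log_partition`, `edge_marginal`, and `finite_difference` are hypothetical.

```python
import itertools
import math

# Toy chain (illustrative assumption): edge[n][i][j] is the log-potential for
# label i at position n followed by label j at position n + 1.
EDGE = [[[0.3, -1.2], [0.8, 0.1]],
        [[-0.5, 0.9], [0.2, -0.7]]]
C = 2  # number of labels


def seq_score(edge, seq):
    """Total log-potential of one label sequence."""
    return sum(edge[n][seq[n]][seq[n + 1]] for n in range(len(edge)))


def log_partition(edge):
    """Brute-force log Z: log-sum-exp of the scores of every label sequence."""
    scores = [seq_score(edge, s)
              for s in itertools.product(range(C), repeat=len(edge) + 1)]
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def edge_marginal(edge, n, i, j):
    """p(z_n = i, z_{n+1} = j), summed directly over all sequences."""
    log_z = log_partition(edge)
    return sum(math.exp(seq_score(edge, s) - log_z)
               for s in itertools.product(range(C), repeat=len(edge) + 1)
               if s[n] == i and s[n + 1] == j)


def finite_difference(edge, n, i, j, eps=1e-6):
    """Central-difference estimate of d log Z / d edge[n][i][j]."""
    def shifted(delta):
        e = [[row[:] for row in mat] for mat in edge]
        e[n][i][j] += delta
        return log_partition(e)
    return (shifted(eps) - shifted(-eps)) / (2 * eps)


for n, i, j in itertools.product(range(len(EDGE)), range(C), range(C)):
    print(n, i, j, round(edge_marginal(EDGE, n, i, j), 6),
          round(finite_difference(EDGE, n, i, j), 6))
```

The two printed columns should agree to about six decimal places; autograd computes the same derivative exactly and in one backward pass.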
```python
# Compute argmax
argmax = DepTree(MaxSemiring).marginals(vals)
show(argmax[0])
```
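Swapping in `MaxSemiring` turns the same gradient trick into an argmax: the derivative of the best score with respect to each potential is 1 on the edges of the best structure and 0 elsewhere (assuming a unique maximum). A pure-Python sketch on a toy chain, again with enumeration and finite differences standing in for the library's dynamic program and autograd; `max_score` and `grad` are hypothetical names.

```python
import itertools

# Toy chain (illustrative assumption): edge[n][i][j] scores label i at n -> j at n + 1.
EDGE = [[[0.3, -1.2], [0.8, 0.1]],
        [[-0.5, 0.9], [0.2, -0.7]]]
C, N = 2, 3  # labels, positions


def max_score(edge):
    """Best total log-potential over all label sequences (brute force)."""
    return max(sum(edge[n][s[n]][s[n + 1]] for n in range(N - 1))
               for s in itertools.product(range(C), repeat=N))


def grad(f, edge, eps=1e-6):
    """Central-difference gradient of f with respect to every edge potential."""
    g = [[[0.0] * C for _ in range(C)] for _ in range(N - 1)]
    for n, i, j in itertools.product(range(N - 1), range(C), range(C)):
        up = [[row[:] for row in mat] for mat in edge]
        down = [[row[:] for row in mat] for mat in edge]
        up[n][i][j] += eps
        down[n][i][j] -= eps
        g[n][i][j] = (f(up) - f(down)) / (2 * eps)
    return g


# Round the gradient to get a 0/1 indicator of the edges used by the best sequence.
argmax = [[[round(v) for v in row] for row in mat] for mat in grad(max_score, EDGE)]
print(argmax)
```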
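```python
# Compute scoring and enumeration (forward / inside)
log_partition = DepTree().sum(vals)           # log of the total weight of all trees
max_score = DepTree(MaxSemiring).sum(vals)    # score of the best tree
argmax_score = DepTree().score(argmax, vals)  # score of the argmax tree; should equal max_score
```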
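Scoring pairs a structure with potentials. A minimal sketch, assuming the score is linear: a 0/1 tensor marking the structure's parts is multiplied elementwise with the potentials and summed. Under that assumption, scoring the argmax structure reproduces the max score. The toy chain and the names `to_parts` and `linear_score` are hypothetical, not the library's implementation.

```python
import itertools

# Toy chain (illustrative assumption): edge[n][i][j] scores label i at n -> j at n + 1.
EDGE = [[[0.3, -1.2], [0.8, 0.1]],
        [[-0.5, 0.9], [0.2, -0.7]]]
C, N = 2, 3


def to_parts(seq):
    """0/1 tensor, same shape as EDGE, marking the edges a sequence uses."""
    parts = [[[0.0] * C for _ in range(C)] for _ in range(N - 1)]
    for n in range(N - 1):
        parts[n][seq[n]][seq[n + 1]] = 1.0
    return parts


def linear_score(parts, potentials):
    """Sum of potentials weighted elementwise by the part indicators."""
    return sum(p * v
               for p_mat, v_mat in zip(parts, potentials)
               for p_row, v_row in zip(p_mat, v_mat)
               for p, v in zip(p_row, v_row))


seqs = list(itertools.product(range(C), repeat=N))
best = max(seqs, key=lambda s: linear_score(to_parts(s), EDGE))
max_score = max(linear_score(to_parts(s), EDGE) for s in seqs)
print(best, linear_score(to_parts(best), EDGE), max_score)
```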
```python
# Compute samples
sample = DepTree(SampledSemiring).marginals(vals)
show(sample[0])
```
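The library draws samples through `SampledSemiring` and a specialized backward pass. For intuition only, the sketch below uses a different, standard exact sampler for chains, forward-filtering backward-sampling: run the forward algorithm, then sample labels right to left, each conditioned on its successor. It is a pure-Python illustration on a toy chain; `sample_chain` is a hypothetical name.

```python
import itertools
import math
import random

# Toy chain (illustrative assumption): edge[n][i][j] scores label i at n -> j at n + 1.
EDGE = [[[0.3, -1.2], [0.8, 0.1]],
        [[-0.5, 0.9], [0.2, -0.7]]]
C = 2


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def sample_chain(edge, rng):
    """Draw one label sequence with probability proportional to exp(score)."""
    # Forward filtering: alphas[n][j] = log-sum over prefixes ending in label j at n.
    alphas = [[0.0] * C]
    for mat in edge:
        prev = alphas[-1]
        alphas.append([logsumexp([prev[i] + mat[i][j] for i in range(C)])
                       for j in range(C)])

    def draw(logits):
        z = logsumexp(logits)
        return rng.choices(range(C), weights=[math.exp(x - z) for x in logits])[0]

    # Backward sampling: last label first, then each label given its successor.
    seq = [draw(alphas[-1])]
    for n in range(len(edge) - 1, -1, -1):
        seq.append(draw([alphas[n][i] + edge[n][i][seq[-1]] for i in range(C)]))
    return tuple(reversed(seq))


rng = random.Random(0)
draws = [sample_chain(EDGE, rng) for _ in range(20000)]

# Compare the empirical frequency of each sequence with its exact probability.
scores = {s: EDGE[0][s[0]][s[1]] + EDGE[1][s[1]][s[2]]
          for s in itertools.product(range(C), repeat=3)}
log_z = logsumexp(list(scores.values()))
for s, sc in scores.items():
    print(s, round(draws.count(s) / len(draws), 3), round(math.exp(sc - log_z), 3))
```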
```python
# Padding/masking is built into the library: here the second example
# is treated as having only its first 7 positions.
marginals = DepTree().marginals(
    vals,
    lengths=torch.tensor([10, 7]))
show(marginals[0])
plt.show()
show(marginals[1])
```
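Why `lengths` matters can be seen on a toy chain: padding a short sequence and running the full dynamic program changes the result, while stopping at the true length recovers the unpadded value. A pure-Python sketch, assuming zero log-potentials as padding; `log_partition` and its `length` argument are hypothetical, not the library's masking code.

```python
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_partition(edge, num_labels, length=None):
    """Chain log Z over the first `length` positions (that is, length - 1 edges).

    edge[n][i][j] scores label i at position n followed by label j at n + 1.
    length=None means "use every position", i.e. no masking.
    """
    if length is None:
        length = len(edge) + 1
    alpha = [0.0] * num_labels
    for mat in edge[:length - 1]:
        alpha = [logsumexp([alpha[i] + mat[i][j] for i in range(num_labels)])
                 for j in range(num_labels)]
    return logsumexp(alpha)


real = [[[0.3, -1.2], [0.8, 0.1]],
        [[-0.5, 0.9], [0.2, -0.7]]]               # a 3-position example
padded = real + [[[0.0, 0.0], [0.0, 0.0]]] * 2    # padded to 5 positions

masked = log_partition(padded, 2, length=3)  # respects the true length
unmasked = log_partition(padded, 2)          # padding leaks into the result
print(masked, log_partition(real, 2), unmasked)
```

Each zero-potential padding edge multiplies Z by the number of labels, so here `unmasked` exceeds `masked` by 2 log 2.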
```python
# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10)
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()
marginals = LinearChain().marginals(chain)
show(marginals[0].sum(-1))
```
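`LinearChain().marginals` returns marginals with the same shape as the edge potentials, and summing out one label gives per-position label marginals, which is what the plot shows. The textbook way to get such numbers is forward-backward, which the library replaces with backprop through the forward pass. A pure-Python sketch for one unbatched chain, assuming `edge[n][i][j]` scores label i at position n followed by label j at n + 1 (this index order is the sketch's own convention and may differ from the library's); `edge_marginals` is a hypothetical name.

```python
import itertools
import math

EDGE = [[[0.3, -1.2], [0.8, 0.1]],
        [[-0.5, 0.9], [0.2, -0.7]]]
C = 2


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def edge_marginals(edge):
    """Forward-backward: p(z_n = i, z_{n+1} = j) for every edge (n, i, j)."""
    alpha = [[0.0] * C]  # alpha[n][i]: log-sum over prefixes ending in i at n
    for mat in edge:
        alpha.append([logsumexp([alpha[-1][i] + mat[i][j] for i in range(C)])
                      for j in range(C)])
    beta = [[0.0] * C]   # beta[n][i]: log-sum over suffixes starting from i at n
    for mat in reversed(edge):
        beta.insert(0, [logsumexp([mat[i][j] + beta[0][j] for j in range(C)])
                        for i in range(C)])
    log_z = logsumexp(alpha[-1])
    return [[[math.exp(alpha[n][i] + edge[n][i][j] + beta[n + 1][j] - log_z)
              for j in range(C)] for i in range(C)] for n in range(len(edge))]


marg = edge_marginals(EDGE)
# Summing out the earlier label gives p(z_{n+1} = j), one row per position.
node = [[sum(marg[n][i][j] for i in range(C)) for j in range(C)] for n in range(len(EDGE))]
print(node)

# Brute-force reference over all label sequences.
seqs = list(itertools.product(range(C), repeat=len(EDGE) + 1))
weights = {s: math.exp(sum(EDGE[n][s[n]][s[n + 1]] for n in range(len(EDGE)))) for s in seqs}
z = sum(weights.values())
exact = [[[sum(w for s, w in weights.items() if s[n] == i and s[n + 1] == j) / z
           for j in range(C)] for i in range(C)] for n in range(len(EDGE))]
```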
Current algorithms implemented:
- Linear Chain (CRF / HMM)
- Semi-Markov (CRF / HSMM)
- Dependency Parsing (Projective and Non-Projective)
- CKY (CFG)
- Integration with torchtext and pytorch-transformers
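The CKY entry computes the inside (log-partition) score of a CFG over all binary bracketings of a sentence. A minimal pure-Python sketch, assuming a grammar in Chomsky normal form with log-potentials `term`, `rule`, and `root` (hypothetical names and layout, not the library's CKY interface). With one nonterminal and all potentials zero, exp of the result counts binary trees, so it should produce the Catalan numbers.

```python
import math


def logsumexp(xs):
    xs = [x for x in xs if x != -math.inf]
    if not xs:
        return -math.inf
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def inside(term, rule, root):
    """Log-space CKY inside algorithm for a grammar in Chomsky normal form.

    term[i][a]    log-potential of nonterminal a over word i
    rule[a][b][c] log-potential of the rule a -> b c
    root[a]       log-potential of a at the root
    Returns the log of the total weight of all parses.
    """
    n, k = len(term), len(root)
    # chart[(i, j)][a]: log-sum over subtrees rooted in a covering words i..j-1.
    chart = {(i, i + 1): list(term[i]) for i in range(n)}
    for width in range(2, n + 1):
        for i in range(n - width + 1):
            j = i + width
            chart[(i, j)] = [
                logsumexp([rule[a][b][c] + chart[(i, s)][b] + chart[(s, j)][c]
                           for s in range(i + 1, j)
                           for b in range(k) for c in range(k)])
                for a in range(k)]
    return logsumexp([root[a] + chart[(0, n)][a] for a in range(k)])


# One nonterminal, every potential 0: exp(inside) counts binary bracketings,
# which are the Catalan numbers 1, 1, 2, 5, 14, 42 for 1..6 words.
for n in range(1, 7):
    print(n, round(math.exp(inside([[0.0]] * n, [[[0.0]]], [0.0]))))
```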
Design Strategy:
- Minimal implementations: most are about 10 lines.
- Batched for GPU.
- Code can be ported to other backends.
Semirings:
- Log Marginals
- Max and MAP computation
- Sampling through specialized backprop
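The semirings listed above can be pictured as one dynamic program with swappable "plus" and "times" operations. A minimal pure-Python sketch on a toy chain: with log-sum-exp as plus, the forward algorithm returns log Z; with max as plus, it returns the best (Viterbi) score. `LOG`, `MAX`, and `forward` are illustrative stand-ins, not torch_struct's semiring classes.

```python
import itertools
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


# A semiring as a (plus, times) pair over log-space scores.
LOG = (logsumexp, lambda a, b: a + b)  # log-sum-exp / add -> log-partition
MAX = (max, lambda a, b: a + b)        # max / add -> best (Viterbi) score


def forward(edge, num_labels, semiring):
    """Chain forward algorithm; edge[n][i][j] scores label i at n -> j at n + 1."""
    plus, times = semiring
    alpha = [0.0] * num_labels  # log-space "one" for both semirings
    for mat in edge:
        alpha = [plus([times(alpha[i], mat[i][j]) for i in range(num_labels)])
                 for j in range(num_labels)]
    return plus(alpha)


EDGE = [[[0.3, -1.2], [0.8, 0.1]],
        [[-0.5, 0.9], [0.2, -0.7]]]

# Brute-force references over all 2**3 label sequences.
scores = [EDGE[0][a][b] + EDGE[1][b][c]
          for a, b, c in itertools.product(range(2), repeat=3)]
print(forward(EDGE, 2, LOG), logsumexp(scores))
print(forward(EDGE, 2, MAX), max(scores))
```

Only the pair passed as `semiring` changes between the two calls; the recursion itself is shared, which is what keeps each structure's implementation short.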
Examples:
- BERT Part-of-Speech tagging
- BERT Dependency Parsing
- Unsupervised Learning
- Structured VAE (to come)
- Structured attention (to come)