SaeedNajafi / pytorch-ocd

Implementation of Optimal Completion Distillation for Sequence Learning


Optimal Completion Distillation (OCD) Training

A PyTorch implementation of the training method from the paper "Optimal Completion Distillation for Sequence Learning".
Paper: https://arxiv.org/abs/1810.01398
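OCD trains on prefixes sampled from the model itself. For each sampled prefix, the paper finds the target prefixes closest to it in edit distance and treats the tokens that follow them (or the end symbol, once the whole target is matched) as optimal next tokens. The following minimal pure-Python sketch shows that step. The function names are hypothetical and the code is not taken from this repository.

```python
from typing import List, Sequence, Set, Tuple


def edit_distance_row(prefix: Sequence[int], target: Sequence[int]) -> List[int]:
    """Return row[j] = Levenshtein distance between `prefix` and target[:j], j = 0..len(target)."""
    row = list(range(len(target) + 1))  # empty prefix vs. target[:j] needs j insertions
    for i, p in enumerate(prefix, start=1):
        new_row = [i]  # prefix[:i] vs. empty target prefix needs i deletions
        for j, t in enumerate(target, start=1):
            new_row.append(min(
                row[j] + 1,             # delete p
                new_row[j - 1] + 1,     # insert t
                row[j - 1] + (p != t),  # match (cost 0) or substitute (cost 1)
            ))
        row = new_row
    return row


def optimal_next_tokens(
    prefix: Sequence[int], target: Sequence[int], end_symbol_id: int
) -> Tuple[Set[int], int]:
    """Return (optimal next tokens, m), where m is the smallest edit distance between
    `prefix` and any prefix of `target`. `target` does not include the end symbol."""
    row = edit_distance_row(prefix, target)
    m = min(row)
    optimal = set()
    for j, dist in enumerate(row):
        if dist == m:
            # target[:j] is a closest match, so target[j] continues an optimal completion;
            # once the whole target is matched, ending the sequence is optimal.
            optimal.add(target[j] if j < len(target) else end_symbol_id)
    return optimal, m


print(optimal_next_tokens([1, 5], [1, 2, 3], end_symbol_id=9))  # ({2, 3}, 1)
```

For the prefix [1, 5], the closest target prefixes are [1] (5 treated as an extra token) and [1, 2] (5 treated as a substitution for 2), so both 2 and 3 are optimal and one edit is unavoidable. Recomputing the whole table at every decoding step is wasteful; carrying `row` forward one sampled token at a time gives the same result.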

Requirements

Python 3, PyTorch 1.0.0

Install

python3 -m venv env
source env/bin/activate
pip3 install .

How to use

See https://github.com/SaeedNajafi/pytorch-ocd/blob/master/ocd/__init__.py#L50 in the ocd package and the test at
https://github.com/SaeedNajafi/pytorch-ocd/blob/master/tests/test_ocd.py#L132 for a complete example. In short:

from ocd import OCD

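# vocab_size: number of output tokens; end_symbol_id: id of the end-of-sequence token
ocd_trainer = OCD(vocab_size=10, end_symbol_id=9)
...  # your model produces scores over all output tokens at every decoding step
ocd_loss = ocd_trainer(model_scores, gold_output_sequence)
...  # backpropagate ocd_loss through the model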
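The call above hides how optimal next tokens become a loss. In the paper, every optimal token gets Q-value -m and every other token -m-1, where m is the smallest edit distance still reachable; the training target is a softmax of these Q-values with temperature τ, which becomes uniform over the optimal tokens as τ → 0, and the model's per-step distribution is pulled toward it. The sketch below turns per-step sets of optimal tokens into that loss. It is an illustration of the paper's description, not the internals of `OCD`: the names are hypothetical and plain Python lists stand in for PyTorch tensors.

```python
import math
from typing import List, Sequence, Set


def log_softmax(scores: Sequence[float]) -> List[float]:
    top = max(scores)  # subtract the max for numerical stability
    log_z = top + math.log(sum(math.exp(s - top) for s in scores))
    return [s - log_z for s in scores]


def ocd_targets(vocab_size: int, optimal: Set[int], temperature: float = 0.0) -> List[float]:
    """pi*(a) proportional to exp(Q*(a) / temperature), with Q* = 0 for optimal tokens and
    -1 otherwise (the shared -m offset cancels in the normalization). `optimal` must be non-empty."""
    if temperature == 0.0:
        # Limit temperature -> 0: uniform over the optimal tokens.
        return [1.0 / len(optimal) if a in optimal else 0.0 for a in range(vocab_size)]
    weights = [math.exp((0.0 if a in optimal else -1.0) / temperature) for a in range(vocab_size)]
    total = sum(weights)
    return [w / total for w in weights]


def ocd_sequence_loss(
    step_scores: Sequence[Sequence[float]],
    step_optimal: Sequence[Set[int]],
    temperature: float = 0.0,
) -> float:
    """Sum over decoding steps of the cross-entropy between pi* and the model's softmax."""
    loss = 0.0
    for scores, optimal in zip(step_scores, step_optimal):
        target = ocd_targets(len(scores), optimal, temperature)
        log_probs = log_softmax(scores)
        loss -= sum(p * lp for p, lp in zip(target, log_probs) if p > 0.0)
    return loss


# Two decoding steps over a 4-token vocabulary with uniform model scores.
print(ocd_sequence_loss([[0.0] * 4, [0.0] * 4], [{1}, {2, 3}]))
```

Using the cross-entropy H(π*, π_θ) instead of the KL divergence only adds the entropy of π*, which does not depend on the model, so the gradients are identical.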

About


License: MIT License


Languages

Python 100.0%