yashmanne / Intra-Distillation-1

This is the repository for our EMNLP 2022 paper "The Importance of Being Parameters: An Intra-Distillation Method for Serious Gains".

Intra-Distillation

@article{xu2022importance,
  title={The Importance of Being Parameters: An Intra-Distillation Method for Serious Gains},
  author={Xu, Haoran and Koehn, Philipp and Murray, Kenton},
  journal={arXiv preprint arXiv:2205.11416},
  year={2022}
}

Reproduction

We consider three tasks in our paper. Please visit the corresponding folder and follow the instructions there to reproduce the results.

Model Card

Intra-Distillation is easy to implement, so here we provide a model card for an easier takeaway.
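
To show how the two components below fit together, here is a minimal sketch of one intra-distillation training step. It assumes, following the paper's setup, that the same batch is passed through the model K times with dropout active (so each pass uses a different sub-network), and that the training loss is the average task loss over the K passes plus alpha times the X-divergence of the K output distributions. The toy model, intra_distillation_loss, and x_divergence (a compact version of X_loss below, without padding handling) are hypothetical illustrations, not code from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def x_divergence(probs):
    # probs: list of K tensors of shape (N, V); each row is a probability distribution.
    m = sum(probs) / len(probs)  # mean distribution over the K passes
    total = 0.0
    for p in probs:
        # Summed (p - m) * (log p - log m) between this pass and the mean
        total = total + ((p - m) * (p.log() - m.log())).sum()
    return total / len(probs)


def intra_distillation_loss(model, x, y, k=3, alpha=1.0):
    # Hypothetical helper: run the same batch K times. Dropout must be active
    # (model.train()) so that each pass samples a different sub-network.
    logits = [model(x) for _ in range(k)]
    task_loss = sum(F.cross_entropy(l, y) for l in logits) / k
    probs = [F.softmax(l, dim=-1) for l in logits]  # X-divergence needs probabilities
    return task_loss + alpha * x_divergence(probs)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.3), nn.Linear(16, 4))
model.train()
x = torch.randn(5, 8)
y = torch.randint(0, 4, (5,))
loss = intra_distillation_loss(model, x, y, k=3, alpha=1.0)
loss.backward()
print(loss.item())
```

With dropout disabled, all K passes are identical, the X-divergence term vanishes, and the loss reduces to plain cross-entropy; in practice alpha could come from the adaptive schedule described below.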

X-divergence

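Given a list logits of K model outputs (one per forward pass) and a boolean padding mask pad_mask, we have the following. Despite the name, each entry of logits must be a probability distribution over the vocabulary (e.g., a softmax output), because the loss takes its logarithm.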

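import torch

def X_loss(logits, pad_mask):
    # logits: list of K tensors of shape (..., dict_size) holding probabilities
    # pad_mask: boolean tensor, True at padding positions
    pad_mask = pad_mask.view(-1)
    non_pad_mask = ~pad_mask
    dict_size = logits[0].size(-1)

    # Mean distribution over the K passes, restricted to non-padding positions
    m = sum(logits) / len(logits)
    m = m.float().view(-1, dict_size)[non_pad_mask]

    kl_all = 0
    for l in logits:
        l = l.float().view(-1, dict_size)[non_pad_mask]
        # Summed over positions: KL(l || m) + KL(m || l)
        d = (l - m) * (torch.log(l) - torch.log(m))
        kl_all += d.sum()
    # Average over the K passes
    return kl_all / len(logits)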
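
Expanding the product shows what each per-pass term measures: the sum of (p - m)(log p - log m) equals the sum of p(log p - log m) plus the sum of m(log m - log p), i.e. KL(p || m) + KL(m || p). X_loss therefore averages, over the K passes, the symmetric KL divergence between each pass's distribution and their mean. The sketch below checks this identity numerically with PyTorch's kl_div, assuming float64 random distributions and no padding; x_term and symmetric_kl are illustrative names, not functions from this repository.

```python
import torch
import torch.nn.functional as F


def x_term(p, m):
    # Per-pass term used by X_loss: sum of (p - m) * (log p - log m)
    return ((p - m) * (p.log() - m.log())).sum()


def symmetric_kl(p, m):
    # F.kl_div(input, target) takes log-probabilities as `input` and computes
    # sum(target * (log target - input)), i.e. KL(target || exp(input)).
    kl_p_m = F.kl_div(m.log(), p, reduction="sum")  # KL(p || m)
    kl_m_p = F.kl_div(p.log(), m, reduction="sum")  # KL(m || p)
    return kl_p_m + kl_m_p


torch.manual_seed(0)
K, N, V = 3, 4, 10
# K "passes", each giving N rows of V-way probability distributions
probs = [F.softmax(torch.randn(N, V, dtype=torch.float64), dim=-1) for _ in range(K)]
m = sum(probs) / K

x_avg = sum(x_term(p, m) for p in probs) / K
skl_avg = sum(symmetric_kl(p, m) for p in probs) / K
print(x_avg.item(), skl_avg.item())
```

The two printed values agree up to floating-point error, and both are zero only when every pass produces the same distribution.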

Adaptive Alpha

Given the maximum value alpha, the current step num_update, the maximum number of updates max_update, and the schedule hyperparameters p and q, we have:

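import torch

def _get_alpha(alpha, num_update, max_update, p, q):
    # Return alpha unchanged once the ramp is over (num_update >= max_update / p)
    # or when alpha <= 1
    if num_update >= max_update / p or alpha <= 1:
        return alpha
    else:
        alpha = torch.tensor([alpha])
        gamma = torch.log(1/alpha) / torch.log(torch.tensor([p/q])) # log_(p/q)(1/alpha)
        # Equivalent to alpha * (p * num_update / max_update) ** gamma
        new_alpha = ( p**gamma * alpha * num_update ** gamma) / (max_update ** gamma)
        return new_alpha.item()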
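
Since p**gamma * alpha * num_update**gamma / max_update**gamma equals alpha * (p * num_update / max_update)**gamma, the weight starts at 0, equals 1 at num_update = max_update / q (the base is then p/q, and (p/q)**gamma = 1/alpha), and reaches alpha at num_update = max_update / p, after which it stays constant. For the weight to increase over training, gamma must be positive, which requires p < q. The sketch below re-expresses _get_alpha with the math module to make this visible; adaptive_alpha and the example values (alpha = 5, max_update = 10000, p = 5, q = 10) are illustrative assumptions, not settings from the paper.

```python
import math


def adaptive_alpha(alpha, num_update, max_update, p, q):
    # Same formula as _get_alpha, written with plain floats:
    # alpha_t = alpha * (p * num_update / max_update) ** gamma,
    # where gamma = log_{p/q}(1 / alpha). Assumes p < q so that gamma > 0.
    if num_update >= max_update / p or alpha <= 1:
        return alpha
    gamma = math.log(1 / alpha) / math.log(p / q)
    return alpha * (p * num_update / max_update) ** gamma


alpha, max_update, p, q = 5.0, 10000, 5, 10  # example values only
for step in [0, 500, 1000, 1500, 2000, 5000]:
    # Weight is 1 at max_update / q = 1000 and alpha from max_update / p = 2000 on
    print(step, round(adaptive_alpha(alpha, step, max_update, p, q), 4))
```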
