iamjanvijay / rnnt

An implementation of RNN-Transducer loss in TF-2.0.

RNN-Transducer Loss

This package provides an implementation of the RNN-Transducer loss in TensorFlow 2.0.

Using the package

First, install the package with pip:

pip install rnnt

Then use the "rnnt_loss" function from the "rnnt" module, as described in the sample script: Sample Train Script

import tensorflow as tf
from rnnt import rnnt_loss

def loss_grad_gradtape(logits, labels, label_lengths, logit_lengths):
    # Record operations on the logits so that the gradient of the
    # loss with respect to them can be computed.
    with tf.GradientTape() as g:
        g.watch(logits)
        loss = rnnt_loss(logits, labels, label_lengths, logit_lengths)
    grad = g.gradient(loss, logits)
    return loss, grad

pred_loss, pred_grads = loss_grad_gradtape(logits, labels, label_lengths, logit_lengths)

The rnnt_loss method expects inputs with the following shapes:
logits - (batch_size, input_time_steps, output_time_steps+1, vocab_size+1)
labels - (batch_size, output_time_steps)
label_lengths - (batch_size,) - number of time steps in each output sequence in the minibatch.
logit_lengths - (batch_size,) - number of time steps in each input sequence in the minibatch.
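For intuition about what the loss computes, here is a minimal NumPy sketch of the RNN-T forward recursion for a single example. This is not the package's actual kernel (which runs batched on TensorFlow tensors); the function name `rnnt_forward_nll` and the single-example, dense-loop formulation are chosen here purely for illustration. The extra `vocab_size+1` slot corresponds to the blank symbol, and the extra `output_time_steps+1` slot to the state before any label has been emitted.

```python
import numpy as np

def rnnt_forward_nll(log_probs, labels, blank=0):
    """Negative log-likelihood of one (input, label) pair via the
    RNN-T forward recursion, computed in log space.

    log_probs: (T, U+1, V) log-softmax outputs of the joint network,
               where T = input time steps, U = label length, and
               V = vocab size including the blank symbol.
    labels:    (U,) target symbol ids.
    """
    T, U1, _ = log_probs.shape
    U = U1 - 1
    # alpha[t, u] = log-probability of all alignments that have
    # consumed t input frames and emitted the first u labels.
    alpha = np.full((T, U1), -np.inf)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U1):
            if t == 0 and u == 0:
                continue
            cands = []
            if t > 0:  # arrive by emitting blank at (t-1, u)
                cands.append(alpha[t - 1, u] + log_probs[t - 1, u, blank])
            if u > 0:  # arrive by emitting label u-1 at (t, u-1)
                cands.append(alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]])
            alpha[t, u] = np.logaddexp.reduce(cands)
    # All alignments end with a final blank emitted from (T-1, U).
    return -(alpha[T - 1, U] + log_probs[T - 1, U, blank])
```

For example, with T=2, U=1 and uniform probabilities of 0.5 over a two-symbol vocabulary, there are exactly two valid alignments, each of probability 0.5³, so the loss is -log(0.25).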

About

License: MIT License
