githubharald / CTCWordBeamSearch

Connectionist Temporal Classification (CTC) decoder with dictionary and language model.

Home Page: https://harald-scheidl.medium.com/b051d28f3d2e

Repository: https://github.com/githubharald/CTCWordBeamSearch

Language model kills performance in beam search

janvainer opened this issue · comments

This is a continuation of a question on Stack Overflow. The question was: why should word-level LM integration work if it decreases the probability of valid prefixes while not scoring prefixes that haven't been decoded into words yet? I checked my dataset and I have an OOV rate of 8%. I also calculated the cross-entropy of a bigram model trained on the training set against the test set. The results are as follows.

vocab size: 39105
oov rate: 0.08589909443725743
cross_entropy_train_test: 6.5233285695619125
train_entropy: 6.483290547147171
test_entropy: 6.314921857881477
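
For reference, a minimal sketch of how such numbers can be computed; it assumes whitespace tokenization and add-one smoothing, and train_text / test_text are placeholders, not the actual (proprietary) data:

import math
from collections import Counter

def bigram_stats(train_text, test_text):
    # toy bigram model with add-one smoothing; returns the OOV rate and the
    # per-token cross-entropy (in nats) of the test set under the model
    train, test = train_text.split(), test_text.split()
    vocab = set(train)
    unigrams = Counter(train)
    bigrams = Counter(zip(train, train[1:]))

    oov_rate = sum(w not in vocab for w in test) / len(test)

    log_prob = 0.0
    for w1, w2 in zip(test, test[1:]):
        # add-one smoothing so unseen bigrams do not zero out the product
        p = (bigrams[(w1, w2)] + 1) / (unigrams[w1] + len(vocab))
        log_prob += math.log(p)
    return oov_rate, -log_prob / (len(test) - 1)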

Unfortunately, I cannot share the distributions from the acoustic model and the corresponding text, since the text is proprietary :/ I do not expect a direct solution, but rather a discussion about why word-level beam search should work. I am now running experiments with a char-level LM and the results seem quite promising. I use the following implementation of beam search.
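
In log space, the combined score the implementation below assigns to a hypothesis $l$ (with $\alpha$ weighting the LM and $\beta$ the word insertion bonus) is

$$\mathrm{score}(l) = \log P_{\mathrm{CTC}}(l \mid x) + \alpha \log P_{\mathrm{LM}}(l) + \beta \log(|\mathrm{words}(l)| + 1)$$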

import math
from collections import defaultdict

import numpy as np

NEG_INF = -float('inf')

def logsumexp(*args):
    # numerically stable log(sum(exp(a_i))) over log-probabilities
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))

def log_beam_search(ctc, alphabet, blank_idx, beam_width, lm=False, char_lm=False,
                    alpha=0.3, beta=5, prune=0, prefix_tree=False, end_symbol='>'):
    # ctc: (T, F) matrix of log-probabilities over the alphabet
    F = ctc.shape[1]
    ctc = np.vstack((np.zeros(F), ctc))  # sentinel row so t - 1 is valid at t = 1
    T = ctc.shape[0]

    # Pb / Pnb: log-probability of each prefix ending in blank / non-blank
    Pb = defaultdict(lambda: defaultdict(lambda: NEG_INF))
    Pnb = defaultdict(lambda: defaultdict(lambda: NEG_INF))
    Pb[0][''] = 0
    Pnb[0][''] = NEG_INF
    A_prev = ['']

    for t in range(1, T):
        # optionally drop characters whose log-probability falls below `prune`
        if prune:
            pruned_alphabet = [alphabet[i] for i in np.where(ctc[t] > prune)[0]]
        else:
            pruned_alphabet = alphabet

        for l in A_prev:

            # prefixes already terminated by the end symbol are frozen
            if len(l) > 0 and l[-1] == end_symbol:
                Pb[t][l] = Pb[t - 1][l]
                Pnb[t][l] = Pnb[t - 1][l]
                continue

            # optionally constrain extensions so the current partial word stays
            # inside the prefix tree (prefix_tree maps a word prefix to its
            # admissible next characters); use a local variable so one prefix's
            # constraint does not leak into the next iteration
            candidates = pruned_alphabet
            if prefix_tree and len(l) > 0 and l[-1] != ' ':
                candidates = prefix_tree(l.split()[-1])

            for c in candidates:
                c_idx = alphabet.index(c)  # todo: use a dict for O(1) lookup instead of O(n)

                if c_idx == blank_idx:

                    # blank extends l without emitting a new character
                    Pb[t][l] = logsumexp(
                        Pb[t][l],
                        ctc[t][blank_idx] + Pb[t - 1][l],
                        ctc[t][blank_idx] + Pnb[t - 1][l]
                    )

                else:

                    l_plus = l + c
                    if len(l) > 0 and c == l[-1]:
                        # repeated character: it only extends the prefix if a
                        # blank separates the two occurrences (Pb path below);
                        # otherwise it collapses into l (Pnb path)
                        ch = alpha * char_lm(l_plus) if char_lm else 0

                        Pnb[t][l_plus] = logsumexp(
                            Pnb[t][l_plus],
                            ctc[t][c_idx] + Pb[t - 1][l] + ch
                        )

                        Pnb[t][l] = logsumexp(
                            Pnb[t][l],
                            ctc[t][c_idx] + Pnb[t - 1][l]
                        )

                    elif len(l.replace(' ', '')) > 0 and c in (' ', end_symbol):

                        # word boundary: score the just-completed word sequence
                        # with the word-level LM
                        lm_prob = 0 if not lm else alpha * lm(l_plus.strip(' >'))

                        Pnb[t][l_plus] = logsumexp(
                                Pnb[t][l_plus],
                                lm_prob + ctc[t][c_idx] + Pb[t - 1][l],
                                lm_prob + ctc[t][c_idx] + Pnb[t - 1][l]
                            )

                    else:

                        # ordinary extension, optionally scored by the
                        # character-level LM
                        ch = alpha * char_lm(l_plus) if char_lm else 0

                        Pnb[t][l_plus] = logsumexp(
                                Pnb[t][l_plus],
                                ctc[t][c_idx] + Pb[t - 1][l] + ch,
                                ctc[t][c_idx] + Pnb[t - 1][l] + ch
                            )

                    # make use of discarded prefixes: l_plus may have dropped
                    # out of the beam at t - 1 but still carries probability mass
                    if l_plus not in A_prev:

                        Pb[t][l_plus] = logsumexp(
                            Pb[t][l_plus],
                            ctc[t][blank_idx] + Pb[t - 1][l_plus],
                            ctc[t][blank_idx] + Pnb[t - 1][l_plus]
                        )

                        Pnb[t][l_plus] = logsumexp(
                            Pnb[t][l_plus],
                            ctc[t][c_idx] + Pnb[t - 1][l_plus]
                        )

        A_next = {
            x: logsumexp(
                Pnb[t].get(x, NEG_INF),
                Pb[t].get(x, NEG_INF)
            )
            for x in set(Pb[t]).union(Pnb[t])
        }
        # word insertion bonus rescoring - do not use the word bonus if no LM
        # is used! The bonus is computed per candidate prefix so that
        # hypotheses with more words are not penalized by the LM alone.
        def score(p):
            bonus = 0 if not lm else beta * math.log(len(p.split()) + 1)
            return A_next[p] + bonus

        A_prev = sorted(A_next, key=score, reverse=True)[:beam_width]
        del Pnb[t - 1], Pb[t - 1]

    return A_prev
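
For reference, a toy invocation could look like the following; the random posteriors and the stand-in LM callable are purely illustrative (a real setup would plug in a trained model such as KenLM):

import numpy as np

# toy alphabet: two letters, space, end symbol, and the CTC blank (last)
alphabet = ['a', 'b', ' ', '>', '%']
blank_idx = len(alphabet) - 1

# random (T, F) log-posteriors standing in for the acoustic model output
T = 50
logits = np.random.randn(T, len(alphabet))
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

# stand-in word LM returning log-probabilities
word_lm = lambda sentence: -1.0 * len(sentence.split())  # crude length penalty

best = log_beam_search(log_probs, alphabet, blank_idx, beam_width=16,
                       lm=word_lm, alpha=0.3, beta=5)
print(best[0])  # highest-scoring prefix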

Sorry, I won't find the time to discuss this any further at the moment.