tongpi / basicOCR

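BasicOCR is a project dedicated to research on text recognition algorithms for natural scenes. The project was started and is maintained by the 佟派 (Tongpi) AI team of 长城数字大数据应用技术研究院 (Great Wall Digital Big Data Application Technology Research Institute).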

Home Page: https://tongpi.github.io/basicOCR/

How can I get the probability of the sequence output by CRNN?

ahmedmazari-dhatim opened this issue

Hello,

I'm wondering whether the CRNN can also output the probability of each decoded sequence.

For example:

--h-e--ll-oo- => 'hello' with probability 0.89

How can I get that?

I can't find these probabilities in the CTCLoss code, and I don't see where I could print the output probabilities in CTCLoss(). In __init__.py the CTC classes are defined as follows:

class _CTC(Function):
    def forward(self, acts, labels, act_lens, label_lens):
        is_cuda = True if acts.is_cuda else False
        acts = acts.contiguous()
        loss_func = warp_ctc.gpu_ctc if is_cuda else warp_ctc.cpu_ctc
        grads = torch.zeros(acts.size()).type_as(acts)
        minibatch_size = acts.size(1)
        costs = torch.zeros(minibatch_size)
        loss_func(acts,
                  grads,
                  labels,
                  label_lens,
                  act_lens,
                  minibatch_size,
                  costs)
        self.grads = grads
        self.costs = torch.FloatTensor([costs.sum()])
        return self.costs

    def backward(self, grad_output):
        return self.grads, None, None, None


class CTCLoss(Module):
    def __init__(self):
        super(CTCLoss, self).__init__()

    def forward(self, acts, labels, act_lens, label_lens):
        """
        acts: Tensor of (seqLength x batch x outputDim) containing output from network
        labels: 1 dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing size of each output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        """
        _assert_no_grad(labels)
        _assert_no_grad(act_lens)
        _assert_no_grad(label_lens)
        return _CTC()(acts, labels, act_lens, label_lens)
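A note on what the code above computes may help here. Assuming the `costs` buffer filled by warp-ctc holds one negative log-likelihood per batch item (as the warp-ctc documentation describes), `exp(-cost)` for a single item is the probability of its label sequence, summed over all alignments. The quoted `forward` only returns `costs.sum()`, so a per-item value would have to be read from `costs` before the sum. The pure-Python sketch below illustrates that quantity with the standard CTC forward recursion; `ctc_label_probability`, `frame_probs`, and the toy numbers are hypothetical names and data for illustration, not part of warp-ctc or this repository.

```python
import math


def ctc_label_probability(frame_probs, label, blank=0):
    """Return p(label | frame_probs), summed over all CTC alignments.

    frame_probs: list of per-frame probability lists (each row sums to 1).
    label: list of class indices, without blanks.
    """
    # Interleave blanks around the label: [a, b] -> [blank, a, blank, b, blank]
    ext = [blank]
    for c in label:
        ext += [c, blank]
    size = len(ext)

    # alpha[s]: total probability of partial alignments ending at ext[s]
    alpha = [0.0] * size
    alpha[0] = frame_probs[0][ext[0]]
    if size > 1:
        alpha[1] = frame_probs[0][ext[1]]

    for t in range(1, len(frame_probs)):
        prev = alpha
        alpha = [0.0] * size
        for s in range(size):
            total = prev[s]                      # stay on the same symbol
            if s >= 1:
                total += prev[s - 1]             # advance by one position
            # Skip over a blank only between two different non-blank symbols
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += prev[s - 2]
            alpha[s] = total * frame_probs[t][ext[s]]

    # A valid alignment ends on the last symbol or on the trailing blank
    return alpha[size - 1] + (alpha[size - 2] if size > 1 else 0.0)


# Toy example: two frames, classes are [blank, 'a']
toy_probs = [[0.6, 0.4], [0.3, 0.7]]
p = ctc_label_probability(toy_probs, [1])   # 'a-' + '-a' + 'aa', about 0.82
cost = -math.log(p)                         # what a CTC loss reports per item
recovered = math.exp(-cost)                 # back to a probability
```

This gives the probability of a label that is already known (for example the ground truth), which is not the same as the confidence of whatever string the decoder outputs.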

@ahmedmazari-dhatim Sorry, we have not implemented anything for this.

Hi @wulivicte ,
Thanks for your answer. So, if I understand correctly, there is no way to get the probabilities using the PyTorch version?

Thanks

@ahmedmazari-dhatim I'm not sure, because our business has had no need for this.

@wulivicte thank you
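For readers with the same question: the CTC loss is not the only place to look, since a score can also be derived from the network output at decoding time. Below is a minimal sketch of greedy (best-path) decoding that also returns the probability of the chosen alignment. It assumes the per-frame outputs have already been turned into probabilities with a softmax over the class dimension and that index 0 is the blank; `greedy_decode_with_prob`, `frame_probs`, and `alphabet` are hypothetical names, not functions from this repository. The returned value is the probability of one alignment such as `--h-e--ll-oo-`, so it is a lower bound on the probability of the collapsed string `hello`, not the exact sequence probability.

```python
def greedy_decode_with_prob(frame_probs, alphabet, blank=0):
    """Best-path CTC decoding.

    frame_probs: list of per-frame probability lists (softmax already applied).
    alphabet: string whose i-th character is class i (alphabet[blank] is unused).
    Returns (decoded_text, probability_of_the_best_alignment).
    """
    best_path = []
    path_prob = 1.0
    for probs in frame_probs:
        # Pick the most likely class for this frame
        k = max(range(len(probs)), key=probs.__getitem__)
        best_path.append(k)
        path_prob *= probs[k]

    # Collapse repeated classes, then drop blanks
    chars = []
    prev = None
    for k in best_path:
        if k != prev and k != blank:
            chars.append(alphabet[k])
        prev = k
    return "".join(chars), path_prob


# Toy example: classes are [blank, 'a', 'b']; best path is a, a, blank, a
demo_probs = [
    [0.10, 0.80, 0.10],
    [0.20, 0.70, 0.10],
    [0.90, 0.05, 0.05],
    [0.10, 0.80, 0.10],
]
text, prob = greedy_decode_with_prob(demo_probs, "-ab")
```

For long sequences it is usually safer to sum log-probabilities instead of multiplying, because the product underflows quickly.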