GRAAL-Research / deepparse

Deepparse is a state-of-the-art library for parsing multinational street addresses using deep learning

Home Page: https://deepparse.org/

KeyError while retraining on custom prediction tags [BUG]

ChargedMonk opened this issue

Describe the bug

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-13-ba3132bfa7ad> in <module>()
----> 1 address_parser.retrain(training_container, 0.8, epochs=5, batch_size=2, num_workers=2, callbacks=[lr_scheduler], prediction_tags=tag_dictionary, logging_path=logging_path)

11 frames
/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
    432             # instantiate since we don't know how to
    433             raise RuntimeError(msg) from None
--> 434         raise exception
    435 
    436 

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.7/dist-packages/deepparse/converter/data_transform.py", line 49, in teacher_forcing_transform
    vectorize_batch_pairs = self.vectorizer(batch_pairs)
  File "/usr/local/lib/python3.7/dist-packages/deepparse/vectorizer/train_vectorizer.py", line 25, in __call__
    input_sequence.extend(
  File "/usr/local/lib/python3.7/dist-packages/deepparse/vectorizer/train_vectorizer.py", line 25, in <listcomp>
    input_sequence.extend(
KeyError: 0

To Reproduce
A KeyError is raised when trying to retrain with new prediction tags:

lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)


tag_dictionary = {'STREET_NUMBER': 0, 'STREET_NAME': 1, 'UNSTRUCTURED_STREET_ADDRESS': 2, 'CITY': 3, 'COUNTRY_SUB_ENTITY': 4, 'COUNTRY': 5, 'POSTAL_CODE': 6, 'EOS': 7}

logging_path = "checkpoints"

address_parser.retrain(training_container, 0.8, epochs=5, batch_size=2, num_workers=2, callbacks=[lr_scheduler], prediction_tags=tag_dictionary, logging_path=logging_path)

Desktop (please complete the following information):

  • OS: Linux
    To be specific, I'm running it on Google Colab.

The error appears to be on your side.

It is raised here:

input_sequence.extend(
            self.embedding_vectorizer([address[0] for address in addresses])
        )  # needs to be passed in batch

This means that some items in your batch are not tuples (they may be empty). To retrain, each data point needs to look like this: ('an address', [tag_1, tag_2, ...]).
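
One way to check the data before calling retrain is sketched below. It is only an illustration, not part of deepparse: the helper name find_malformed_pairs is hypothetical, and it assumes that tags are strings that appear as keys of the prediction_tags dictionary and that there is one tag per whitespace-separated token. Note that address[0] raises KeyError: 0 when the item is a dict without the key 0, so a dict-shaped data point is one plausible cause of the traceback above.

```python
from typing import Any, Dict, List, Sequence, Tuple


def find_malformed_pairs(
    data: Sequence[Any], tag_dictionary: Dict[str, int]
) -> List[Tuple[int, str]]:
    """Return (index, reason) for each item not shaped like ('an address', [tag_1, tag_2, ...]).

    Hypothetical helper for illustration; it is not a deepparse function.
    """
    problems = []
    for index, item in enumerate(data):
        # Each data point must be a pair: (address string, list of tags).
        if not isinstance(item, (tuple, list)) or len(item) != 2:
            problems.append((index, f"expected a 2-element tuple, got {type(item).__name__}"))
            continue
        address, tags = item
        if not isinstance(address, str) or not address.strip():
            problems.append((index, "address is empty or not a string"))
            continue
        if not isinstance(tags, (list, tuple)):
            problems.append((index, "tags are not a list"))
            continue
        # Assumption: one tag per whitespace-separated token of the address.
        if len(tags) != len(address.split()):
            problems.append((index, "number of tags differs from number of tokens"))
            continue
        # Assumption: tags are strings that are keys of the tag dictionary.
        unknown = [tag for tag in tags if tag not in tag_dictionary]
        if unknown:
            problems.append((index, f"unknown tags: {unknown}"))
    return problems


# Made-up sample data: one valid pair, one dict, one empty pair.
sample_tags = {"STREET_NUMBER": 0, "STREET_NAME": 1, "CITY": 2, "EOS": 3}
sample_data = [
    ("350 Main Street Quebec", ["STREET_NUMBER", "STREET_NAME", "STREET_NAME", "CITY"]),
    {"address": "350 Main Street", "tags": []},  # item[0] on this dict raises KeyError: 0
    ("", []),
]

for index, reason in find_malformed_pairs(sample_data, sample_tags):
    print(f"item {index}: {reason}")
```

Running a check like this over the contents of the training container before retraining should point to the offending items directly, instead of the error surfacing inside a DataLoader worker.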

I will make the error handling clearer and add a better example.

See the new release for better error handling.