KeyError while retraining on custom prediction tags [BUG]
ChargedMonk opened this issue
Vatsalya Bajpai commented
Describe the bug
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-13-ba3132bfa7ad> in <module>()
----> 1 address_parser.retrain(training_container, 0.8, epochs=5, batch_size=2, num_workers=2, callbacks=[lr_scheduler], prediction_tags=tag_dictionary, logging_path=logging_path)
11 frames
/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
432 # instantiate since we don't know how to
433 raise RuntimeError(msg) from None
--> 434 raise exception
435
436
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
return self.collate_fn(data)
File "/usr/local/lib/python3.7/dist-packages/deepparse/converter/data_transform.py", line 49, in teacher_forcing_transform
vectorize_batch_pairs = self.vectorizer(batch_pairs)
File "/usr/local/lib/python3.7/dist-packages/deepparse/vectorizer/train_vectorizer.py", line 25, in __call__
input_sequence.extend(
File "/usr/local/lib/python3.7/dist-packages/deepparse/vectorizer/train_vectorizer.py", line 25, in <listcomp>
input_sequence.extend(
KeyError: 0
To Reproduce
A KeyError is raised when trying to retrain with new prediction tags:
lr_scheduler = poutyne.StepLR(step_size=1, gamma=0.1)
tag_dictionary = {'STREET_NUMBER': 0, 'STREET_NAME': 1, 'UNSTRUCTURED_STREET_ADDRESS': 2, 'CITY': 3, 'COUNTRY_SUB_ENTITY': 4, 'COUNTRY': 5, 'POSTAL_CODE': 6, 'EOS': 7}
logging_path = "checkpoints"
address_parser.retrain(training_container, 0.8, epochs=5, batch_size=2, num_workers=2, callbacks=[lr_scheduler], prediction_tags=tag_dictionary, logging_path=logging_path)
Desktop (please complete the following information):
- OS: Linux
To be specific, I'm running it on Google Colab.
David Beauchemin commented
The error looks like it is on your side.
It happened here:
input_sequence.extend(
self.embedding_vectorizer([address[0] for address in addresses])
) # need to be pass in batch
This means that some of the addresses in your batch are not tuples (they may be empty). To retrain, each item in your data needs to look like this: ('an address', [tag_1, tag_2, ...]).
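The data format described above can be checked before calling retrain. The sketch below is a minimal, standalone check and is not part of deepparse: the helper name `find_malformed_pairs` and the sample tag names are hypothetical, and the rule of one tag per whitespace-separated word is an assumption. For context, `address[0]` raises `KeyError: 0` in plain Python when the item is a dict without a key `0`, so a dataset of dicts rather than tuples is one possible cause of the traceback.

```python
from typing import Any, List, Tuple


def find_malformed_pairs(data: List[Any]) -> List[Tuple[int, str]]:
    """Return (index, reason) for every item that is not ('an address', [tag, ...]).

    Hypothetical helper, not a deepparse function.
    """
    problems = []
    for index, item in enumerate(data):
        # Each item must be a 2-element tuple (or list): (address, tags).
        if not isinstance(item, (tuple, list)) or len(item) != 2:
            problems.append((index, "not a 2-element (address, tags) pair"))
            continue
        address, tags = item
        if not isinstance(address, str) or not address.strip():
            problems.append((index, "address is empty or not a string"))
        elif not isinstance(tags, (list, tuple)) or len(tags) == 0:
            problems.append((index, "tags are missing or empty"))
        # Assumption: one tag per whitespace-separated word of the address.
        elif len(address.split()) != len(tags):
            problems.append((index, "number of tags differs from number of words"))
    return problems


# Illustrative data only; the tag names are made up for this example.
sample = [
    ("350 rue des Lilas", ["STREET_NUMBER", "STREET_NAME", "STREET_NAME", "STREET_NAME"]),
    {"address": "12 main street", "tags": []},  # a dict: item[0] would raise KeyError: 0
    ("", []),                                   # an empty address
]

for index, reason in find_malformed_pairs(sample):
    print(f"item {index}: {reason}")
```

Running a check like this on the raw data, before wrapping it in a dataset container, points at the offending items directly instead of surfacing them as an error inside a DataLoader worker.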
I will make the error handling clearer and add a better example.
David Beauchemin commented
See the new release for better error handling.