hobson / Text-Classification-CNN-PyTorch

Tweet text classification with PyTorch 1-D CNN. Inspired by [Fernando Lopez's](https://github.com/FernandoLpz/Text-Classification-CNN-PyTorch) incorrect implementation of "Convolutional Neural Networks for Sentence Classification"


Corrected Text Classification with CNNs in PyTorch

A re-implementation and simplification of Fernando Lopez's (@FernandoLpz) Text Classification CNN, which he based on the paper "Convolutional Neural Networks for Sentence Classification" (Yoon Kim, 2014). He has some nice diagrams in his (paywalled) Medium blog post, Text Classification with CNNs in PyTorch.

Preprocessing

For the book we improved Fernando's pipeline to make it more readable, pythonic, and accurate. Just as in previous chapters (and in general), you do not want to use case folding or stopword removal for your first model. Give your model all the information you can for your first attempt at training it. You only need to filter stopwords and lowercase your text if you have a dimensionality problem (information overload). You can tune your model later to deal with overfitting or slow convergence, and there are usually more effective ways to deal with overfitting anyway.

The preprocessing pipeline here is the same one we used for chapters 1-6 in NLPiA. You can see where we've changed the implementation suggested by Fernando.

  1. Simplified: load the CSV

df = pd.read_csv(filepath, usecols=['text', 'target'])

  2. NOPE: case folding:

texts = df['text'].str.lower()

  3. NOPE: remove non-letters (nonalpha):

texts = [re.sub(r'[^A-Za-z]', ' ', text) for text in texts]

  4. NOPE: remove stopwords

  5. Simplified: tokenize with regex

tokenized_texts = [re.compile(r'\w+').findall(text) for text in df['text']]

  6. Simplified: filter infrequent words

counts = Counter(chain(*tokenized_texts))
id2tok = vocab = [tok for tok, count in counts.most_common(200)]

  7. Simplified: compute reverse index

tok2id = dict(zip(id2tok, range(len(id2tok))))

  8. Simplified: transform token sequences to integer id sequences

id_sequences = [[tok2id[tok] for tok in tokens if tok in tok2id] for tokens in tokenized_texts]

  9. Simplified: pad token id sequences

  10. Simplified: train_test_split

x_train, x_test, y_train, y_test = train_test_split(id_sequences, df['target'])
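The steps above can be sketched end to end. This is a dependency-free illustration: plain Python lists stand in for the pandas DataFrame, a fixed slice stands in for sklearn's shuffled train_test_split, and the toy tweets and pad-id convention are assumptions, not the book's exact code.

```python
import re
from collections import Counter
from itertools import chain

texts = ['Earthquake hits the city', 'I love this sunny day',
         'Flood warning issued today', 'Great coffee this morning']
targets = [1, 0, 1, 0]

# tokenize with a regex -- no case folding, no stopword removal
tokenize = re.compile(r'\w+').findall
tokenized_texts = [tokenize(t) for t in texts]

# filter infrequent words: keep only the most common tokens
counts = Counter(chain(*tokenized_texts))
id2tok = [tok for tok, n in counts.most_common(200)]
tok2id = dict(zip(id2tok, range(len(id2tok))))

# transform token sequences into integer id sequences
id_sequences = [[tok2id[tok] for tok in toks if tok in tok2id]
                for toks in tokenized_texts]

# pad every id sequence to the same length with a reserved pad id
maxlen = max(len(seq) for seq in id_sequences)
pad_id = len(id2tok)  # one id past the end of the vocabulary
padded = [seq + [pad_id] * (maxlen - len(seq)) for seq in id_sequences]

# train/test split (sklearn's train_test_split would shuffle first)
split = len(padded) * 3 // 4
x_train, x_test = padded[:split], padded[split:]
y_train, y_test = targets[:split], targets[split:]
```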

Integer id sequences like these are common in PyTorch implementations, but the integers themselves should never be used as the input to a deep learning model: the numerical value of a word index carries no information about the word other than its position in a one-hot vector. The index values are arbitrary, so treating them as numbers destroys the information contained in a sentence. A word index is fundamentally a categorical variable, not an ordinal or numerical value. Fernando's code effectively randomizes the 1-D numerical representation of each word, so the sequence of indices passed into the CNN is meaningless.

You can see this in the poor accuracy on the test set compared to the training set: test set accuracy is little better than random guessing, or always guessing the majority class. In PyTorch, a word embedding layer (torch.nn.Embedding) is typically used to turn these categorical ids back into meaningful vectors.
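A plain-Python sketch of what an embedding layer does may make this concrete: it is a learned lookup table that maps each categorical word id to a dense vector. Here the table is randomly initialized for illustration; in PyTorch its rows are trained by backpropagation.

```python
import random

random.seed(0)
vocab_size, embedding_dim = 5, 3

# randomly initialized embedding matrix (trained end-to-end in practice)
embedding = [[random.gauss(0, 1) for _ in range(embedding_dim)]
             for _ in range(vocab_size)]

def embed(id_sequence):
    """Replace each arbitrary integer id with its learned dense vector."""
    return [embedding[i] for i in id_sequence]

vectors = embed([3, 1, 1])
```

Both occurrences of id 1 map to the identical row, so the model sees the word's learned representation rather than the meaningless magnitude of its index.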

This repository attempts to correct this fundamental error by replacing word indices with progressively more informative numerical representations of words:

  • removing the CNN entirely
  • 2-D CNN on one-hot vectors
  • 2-D CNN on TF-IDF vectors
  • 2-D CNN on LSA vectors (PCA on TF-IDF vectors)
  • 2-D CNN on word vectors
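The first vector-based variant above can be sketched with one-hot rows: each tweet becomes a (seq_len, vocab_size) matrix that a 2-D CNN can convolve over, instead of a 1-D sequence of meaningless integer ids. The toy vocabulary here is an assumption for illustration.

```python
id2tok = ['flood', 'hits', 'the', 'city']
vocab_size = len(id2tok)

def one_hot_matrix(id_sequence, vocab_size):
    """Stack one one-hot row per token id: shape (seq_len, vocab_size)."""
    return [[1.0 if j == i else 0.0 for j in range(vocab_size)]
            for i in id_sequence]

mat = one_hot_matrix([0, 1, 2, 3], vocab_size)
```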

In addition, a baseline LinearModel (a single-layer perceptron, or fully-connected feed-forward layer) is trained on the average of the word vectors for a sentence or tweet (spacy.Doc.vector).
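The baseline can be sketched in plain Python: average the word vectors of a tweet (which is what spacy.Doc.vector computes from its token vectors), then apply a single fully connected layer with a sigmoid. The random word vectors and untrained weights below are stand-ins for illustration only.

```python
import math
import random

random.seed(1)
dim = 4

# random stand-ins for pretrained spacy token vectors
word_vectors = {w: [random.gauss(0, 1) for _ in range(dim)]
                for w in ['flood', 'hits', 'the', 'city']}

def doc_vector(tokens):
    """Mean of the token vectors for the whole tweet."""
    vecs = [word_vectors[t] for t in tokens]
    return [sum(col) / len(vecs) for col in zip(*vecs)]

weights = [random.gauss(0, 1) for _ in range(dim)]
bias = 0.0

def linear_model(tokens):
    """Single-layer perceptron: sigmoid(w . doc_vector + b)."""
    v = doc_vector(tokens)
    z = sum(wi * vi for wi, vi in zip(weights, v)) + bias
    return 1 / (1 + math.exp(-z))

p = linear_model(['flood', 'hits', 'the', 'city'])
```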


License: MIT


Languages

Language: Python 100.0%