sashank06 / THRED

The implementation of the paper "Augmenting Neural Response Generation with Context-Aware Topical Attention"

Home Page:https://arxiv.org/abs/1811.01063

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Codacy Badge

This repository hosts the implementation of the paper "Augmenting Neural Response Generation with Context-Aware Topical Attention".

Topical Hierarchical Recurrent Encoder Decoder (THRED)

THRED is a multi-turn response generation system intended to produce contextual and topic-aware responses. The codebase is evolved from the Tensorflow NMT repository.

TL;DR Steps to create a dialogue agent using this framework:

  1. Download the Reddit Conversation Corpus from here (7.8GB file containing triples extracted from Reddit). Please report errors/inappropriate content in the data here.
  2. Install the dependencies using conda env create -f thred_env.yml (To use pip, see Dependencies)
  3. Train the model using the following command (pretrained models will be published soon). Note that MODEL_DIR is a directory that the model will be saved into. We recommend to train on at least 2 GPUs, otherwise you can reduce the data size (by omitting conversations from the training file) and the model size (by modifying the config file).
python -m thred --mode train --config conf/thred_medium.yml --model_dir <MODEL_DIR> \
--train_data <TRAIN_DATA> --dev_data <DEV_DATA> --test_data <TEST_DATA>
  1. Chat with the trained model using:
python -m thred --mode interactive --model_dir <MODEL_DIR>

Dependencies

  • Python >= 3.5 (Recommended: 3.6)
  • Tensorflow == 1.12.0
  • Tensorflow-Hub
  • SpaCy >= 2.1.0
  • pymagnitude
  • tqdm
  • redis1
  • mistune1
  • emot1
  • Gensim1
  • prompt-toolkit2

1packages required only for parsing and cleaning the Reddit data. 2used only for testing dialogue models in command-line interactive mode

To install the dependencies using pip, run pip install -r requirements. And for Anaconda, run conda env create -f thred_env.yml (recommended). Once done with the dependencies, run pip install -e . to install the thred package.

Data

Our Reddit dataset, which we call Reddit Conversation Corpus (RCC), is collected from 95 selected subreddits (listed here). We processed Reddit for a 20 month-period ranging from November 2016 until August 2018 (excluding June 2017 and July 2017; we utilized these two months along with the October 2016 data to train an LDA model). Please see here for the details of how the Reddit dataset is built including pre-processing and cleaning the raw Reddit files. The following table summarizes the RCC information:

Corpus #train #dev #test Download Download with topic words
3 turns per line 9.2M 508K 406K download (1.9GB) download (7.8GB)
4 turns per line 4M 223K 178K download (1.1GB) download (3.7GB)
5 turns per line 1.8M 100K 80K download (609.7MB) download (1.8GB)

In the data files, each line corresponds to a single conversation where utterances are TAB-separated. The topic words appear after the last utterance separated also by a TAB.

Note that the 3-turns/4-turns/5-turns files contain similar content albeit with different number of utterances per line. They are all extracted from the same source. If you found any error or any inappropriate utterance in the data, please report your concerns here.

Embeddings

In the model config files (i.e., the YAML files in conf), the embedding types can be either of the following: glove840B, fastText, word2vec, and hub_word2vec. For handling the pre-trained embedding vectors, we leverage Pymagnitude and Tensorflow-Hub. Note that you can also use random300 (300 refers to the dimension of embedding vectors and can be replaced by any arbitrary value) to learn vectors during training of the response generation models. The settings related to embedding models are provided in word_embeddings.yml.

Train

The training configuration should be defined in a YAML file similar to Tensorflow NMT. Sample configurations for THRED and other baselines are provided here.

The implemented models are Seq2Seq, HRED, Topic Aware-Seq2Seq, and THRED.

Note that while most of the parameters are common among the different models, some models may have additional parameters (e.g., topical models have topic_words_per_utterance and boost_topic_gen_prob parameters).

To train a model, run the following command:

python main.py --mode train --config <YAML_FILE> \
--train_data <TRAIN_DATA> --dev_data <DEV_DATA> --test_data <TEST_DATA> \
--model_dir <MODEL_DIR>

In <MODEL_DIR>, vocabulary files and Tensorflow model files are stored. Training can be resumed by executing:

python main.py --mode train --model_dir <MODEL_DIR>

Test

With the following command, the model can be tested against the test dataset.

python main.py --mode test --model_dir <MODEL_DIR> --test_data <TEST_DATA>

It is possible to override test parameters during testing. These parameters are: beam width --beam_width, length penalty weight --length_penalty_weight, and sampling temperature --sampling_temperature.

A simple command line interface is implemented that allows you to converse with the learned model (Similar to test mode, the test parameters can be overrided too):

python main.py --mode interactive --model_dir <MODEL_DIR>

In the interactive mode, a pre-trained LDA model is required to feed the inferred topic words into the model. We trained an LDA model using Gensim on a Reddit corpus, collected for this purpose. It can be downloaded from here. The downloaded file should be uncompressed and passed to the program via --lda_model_dir <DIR>.

Citation

Please cite the following paper if you used our work in your research:

@article{dziri2018augmenting,
  title={Augmenting Neural Response Generation with Context-Aware Topical Attention},
  author={Dziri, Nouha and Kamalloo, Ehsan and Mathewson, Kory W and Zaiane, Osmar R},
  journal={arXiv preprint arXiv:1811.01063},
  year={2018}
}

About

The implementation of the paper "Augmenting Neural Response Generation with Context-Aware Topical Attention"

https://arxiv.org/abs/1811.01063

License:MIT License


Languages

Language:Python 98.8%Language:Shell 1.2%