noahchalifour / memn2n

An end-to-end goal-oriented dialog model in TensorFlow

MemN2N

This repository is a Python implementation, using TensorFlow 2, of Facebook's research paper Learning End-to-End Goal-Oriented Dialog.

bAbI Dialog

Training a model

To start training a model, first download the bAbI dialog dataset, which can be found here.

Once you have extracted the dataset, run the following commands to set up the repository:

git clone https://github.com/noahchalifour/memn2n
cd memn2n
pip install tf-nightly==2.2.0.dev20200130 # or tf-nightly-gpu==2.2.0.dev20200130 for GPU version
pip install -r requirements.txt

Once the code is set up, you can modify the model hyperparameters in the utils/hparams.py file. By default the hyperparameters are set up to run a number of tests, so if you just want to train a single model, set each hyperparameter to a list containing only the single value you want to use.

After modifying the hyperparameters (this step is optional), run the following command to start training:
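To illustrate why single-value lists lead to a single training run, here is a minimal sketch of a grid expansion. The `HPARAM_GRID` dictionary and the `expand_grid` helper are hypothetical and are not taken from utils/hparams.py, whose actual layout may differ; the hyperparameter names and values are the ones listed in the Results section.

```python
import itertools

# Hypothetical grid: each hyperparameter maps to a list of candidate values.
# This is an assumption for illustration, not the real utils/hparams.py.
HPARAM_GRID = {
    "embedding_size": [32],
    "memory_size": [50],
    "memory_hops": [3],
    "learning_rate": [1e-3],
    "batch_size": [32],
}


def expand_grid(grid):
    """Return one dict per combination of values in the grid."""
    names = list(grid)
    combos = itertools.product(*grid.values())
    return [dict(zip(names, combo)) for combo in combos]


runs = expand_grid(HPARAM_GRID)
# Every list holds a single value, so there is exactly one combination to train.
print(len(runs), runs[0])
```

Adding a second value to any list doubles the number of combinations, which is presumably why the defaults result in many runs.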

python run_babi_dialog.py \
    --mode train \
    --task {{ task_id }} \
    --data_dir {{ babi_dialog_dir }}

Testing your model

Once you have a trained model, you can test it with the following command:

python run_babi_dialog.py \
    --mode test \
    --task {{ task_id }} \
    --data_dir {{ babi_dialog_dir }} \
    --model_dir {{ model_dir }} \
    --use_oov {{ true to test on OOV, false otherwise }}
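If you want to evaluate the same model both with and without OOV, the test command can be assembled from a script. The following is a rough sketch that only uses the flags shown above; the `build_test_command` helper, the example paths, and the assumption that task ids are plain numbers are all hypothetical.

```python
def build_test_command(task_id, data_dir, model_dir, use_oov):
    """Build the argument list for a test run of run_babi_dialog.py."""
    return [
        "python", "run_babi_dialog.py",
        "--mode", "test",
        "--task", str(task_id),
        "--data_dir", data_dir,
        "--model_dir", model_dir,
        # The flag takes the literal strings "true" or "false".
        "--use_oov", "true" if use_oov else "false",
    ]


# Hypothetical task id and paths, for illustration only.
for use_oov in (False, True):
    cmd = build_test_command(1, "path/to/babi_dialog", "path/to/model", use_oov)
    print(" ".join(cmd))
    # To actually launch the run: subprocess.run(cmd, check=True)
```

Building the command as a list rather than a single string avoids shell quoting problems when paths contain spaces.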

Results

All of the following models were trained for 200 epochs with embedding_size = 32, memory_size = 50, memory_hops = 3, learning_rate = 1e-03, and batch_size = 32. Better results can be achieved by tuning the hyperparameters, most notably embedding_size and memory_hops.

| Task | Original Paper | This Repository |
| --- | --- | --- |
| T1: Issuing API calls | 99.9 | 99.9 |
| T2: Updating API calls | 100 | 99.9 |
| T3: Displaying options | 74.9 | 74.6 |
| T4: Providing information | 59.5 | 56.7 |
| T5: Full dialogs | 96.1 | 92.6 |
| T1 (OOV): Issuing API calls | 72.3 | 81.9 |
| T2 (OOV): Updating API calls | 78.9 | 78.8 |
| T3 (OOV): Displaying options | 74.4 | 69.2 |
| T4 (OOV): Providing information | 57.6 | 57.1 |
| T5 (OOV): Full dialogs | 65.5 | 63.0 |
| T6: Dialog state tracking 2 | 41.1 | 39.2 |
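To see at a glance where this repository diverges from the paper, the per-task gap can be computed from the numbers above. This is only a small sketch; the `RESULTS` dictionary and the `deltas` helper are illustrative names, and the values are copied from the table.

```python
# (original paper, this repository) scores copied from the results table.
RESULTS = {
    "T1": (99.9, 99.9),
    "T2": (100.0, 99.9),
    "T3": (74.9, 74.6),
    "T4": (59.5, 56.7),
    "T5": (96.1, 92.6),
    "T1 (OOV)": (72.3, 81.9),
    "T2 (OOV)": (78.9, 78.8),
    "T3 (OOV)": (74.4, 69.2),
    "T4 (OOV)": (57.6, 57.1),
    "T5 (OOV)": (65.5, 63.0),
    "T6": (41.1, 39.2),
}


def deltas(results):
    """Return repository score minus paper score per task, rounded to 0.1."""
    return {task: round(repo - paper, 1) for task, (paper, repo) in results.items()}


for task, delta in deltas(RESULTS).items():
    print(f"{task}: {delta:+.1f}")
```

A positive value means this repository scored higher than the paper on that task.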

License: MIT License

