jiajingchen113322 / CIS700_Project

CIS 700 Project

This is the repository for the term project of the course CIS 700, Deep Learning Theorem Proving. We trained a GGNN (Gated Graph Neural Network) on task 15 of the bAbI dataset.

Team members

Dependencies

  • pytorch
  • tensorboard
  • PyYAML
  • numpy
  • tqdm

If you want to use a GPU to accelerate training, make sure that CUDA is installed before passing the option --device cuda to the execution command.
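As a quick check before passing --device cuda, the sketch below reports whether PyTorch can actually see a CUDA device and falls back to the CPU otherwise. The helper name resolve_device is hypothetical and not part of this repository; the only library call it relies on is torch.cuda.is_available().

```python
def resolve_device(requested: str = "cuda") -> str:
    """Return the requested device, falling back to "cpu" when CUDA is unusable.

    Hypothetical helper for illustration; main.py may handle this differently.
    """
    if requested != "cuda":
        return requested
    try:
        import torch  # imported lazily so the check also runs without PyTorch
    except ImportError:
        return "cpu"
    # torch.cuda.is_available() is True only if a usable CUDA device is found.
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(resolve_device("cuda"))
```

If this prints cpu, run the training command with --device cpu instead.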

Usage

To run the program, simply execute the following command:

python main.py

Several options can be passed to this command. They are listed below:

  • --exp_name: Name of the experiment. A directory within Experiment will be created with this name to record the training logs. Default: default.
  • --epochs: Number of training epochs. Default: 50.
  • --train: Flag for training. Default: True
  • --data_path: Path for the data. Default: babi_data/processed_1/train/15_graphs.txt
    • Depending on the OS you are using, this path may need to be adjusted.
  • --batch_size: Batch size. Default: 15
  • --lr: Learning rate. Default: 0.01
  • --device: Device used to train the model. Default: cuda
    • Again, if this option is set to cuda, please make sure that CUDA is installed.
  • --opt: Optimizer used in training. Options: SGD, Adam. Default: Adam
  • --state_dim: Dimension for states. Default: 4
  • --annotation_dim: Annotation dimension. Default: 1
  • --edge_type: Number of edge types in the graph. Default: 2
  • --n_nodes: Number of nodes in the graph. Default: 8
  • --n_step: Number of propagation steps. Default: 5
  • --attention: If given, the GGNN with attention is used, otherwise the original GGNN is used.
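To make the option list concrete, here is a minimal sketch of how these options and defaults could be declared with argparse. It is an illustration built only from the list above, not the actual contents of main.py: the function name build_parser is hypothetical, and the way --train is parsed is an assumption. The default data path is assembled with pathlib so that the separator matches the current OS.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented options and defaults."""
    parser = argparse.ArgumentParser(description="GGNN on bAbI task 15 (sketch)")
    parser.add_argument("--exp_name", default="default")
    parser.add_argument("--epochs", type=int, default=50)
    # Assumption: --train takes an explicit true/false value.
    parser.add_argument(
        "--train",
        type=lambda s: s.lower() in ("1", "true", "yes"),
        default=True,
    )
    # Build the default path from its parts so it is valid on any OS.
    default_data = Path("babi_data") / "processed_1" / "train" / "15_graphs.txt"
    parser.add_argument("--data_path", default=str(default_data))
    parser.add_argument("--batch_size", type=int, default=15)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--opt", choices=["SGD", "Adam"], default="Adam")
    parser.add_argument("--state_dim", type=int, default=4)
    parser.add_argument("--annotation_dim", type=int, default=1)
    parser.add_argument("--edge_type", type=int, default=2)
    parser.add_argument("--n_nodes", type=int, default=8)
    parser.add_argument("--n_step", type=int, default=5)
    # Present on the command line -> attention variant; absent -> original GGNN.
    parser.add_argument("--attention", action="store_true")
    return parser


if __name__ == "__main__":
    # Example: override a few defaults, as one would on the command line.
    args = build_parser().parse_args(["--epochs", "100", "--device", "cpu", "--attention"])
    print(args.epochs, args.device, args.attention, args.data_path)
```

With the documented options, a CPU run of the attention variant would be started as python main.py --exp_name attention_run --device cpu --attention, where attention_run is just an example experiment name.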

To view the training curves after training, run this command in the project's root directory:

tensorboard --logdir Experiment

Acknowledgement

The model and data used in this repository are based on this repository.
The original bAbI project can be found here.
