leftthomas / CCCapsNet

A PyTorch implementation of the Compositional Coding Capsule Network, based on the PRL 2022 paper "Compositional Coding Capsule Network with K-Means Routing for Text Classification".

CCCapsNet

Requirements

  • PyTorch
conda install pytorch torchvision -c pytorch
  • PyTorchNet
pip install git+https://github.com/pytorch/tnt.git@master
  • PyTorch-NLP
pip install pytorch-nlp
  • capsule-layer
pip install git+https://github.com/leftthomas/CapsuleLayer.git@master
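You can confirm that the requirements above are importable before running anything with a minimal check like the one below. It is a sketch and is not part of the repo. It assumes the import names torch, torchvision, torchnet (PyTorchNet), torchnlp (PyTorch-NLP) and capsule_layer (capsule-layer).

```python
import importlib.util

# Import names assumed for each requirement; pip/conda package names differ
# from the module names you actually import.
REQUIREMENTS = {
    "PyTorch": "torch",
    "torchvision": "torchvision",
    "PyTorchNet": "torchnet",
    "PyTorch-NLP": "torchnlp",
    "capsule-layer": "capsule_layer",
}


def missing_requirements(requirements=REQUIREMENTS):
    """Return the requirement names whose top-level module cannot be found.

    find_spec only locates the module without importing it, so the check
    stays cheap even for heavy packages such as torch.
    """
    return [name for name, module in requirements.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_requirements()
    print("all requirements found" if not missing else f"missing: {', '.join(missing)}")
```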

Datasets

The original AGNews, AmazonReview, DBPedia, YahooAnswers, SogouNews and YelpReview datasets come from here.

The original Newsgroups, Reuters, Cade and WebKB datasets can be found here.

The original IMDB dataset is downloaded automatically by PyTorch-NLP.

We have uploaded all the original datasets to BaiduYun (access code: kddr) and Google Drive. The preprocessed datasets have been uploaded to BaiduYun (access code: 2kyd) and Google Drive.

You don't need to download the datasets yourself, because the code downloads them automatically. If you run into network issues, download the datasets from the cloud storage links above and extract them into the data directory.
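If you download the archives manually, a sketch like the one below unpacks them into the data directory. The download folder is hypothetical, and so is the assumption that data is a directory relative to the repo root. Adjust both to your setup.

```python
import shutil
from pathlib import Path


def extract_archives(download_dir, data_dir="data"):
    """Unpack every archive found in download_dir into data_dir.

    Returns the names of the archives that were extracted. Files that are
    not archives shutil recognises are skipped.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    extracted = []
    for archive in sorted(Path(download_dir).iterdir()):
        if not archive.is_file():
            continue
        try:
            # shutil picks the format (zip, tar, tar.gz, ...) from the extension
            shutil.unpack_archive(str(archive), str(data_dir))
        except shutil.ReadError:
            continue  # unknown extension or not a valid archive
        extracted.append(archive.name)
    return extracted
```

For example, `extract_archives(Path.home() / "Downloads" / "cccapsnet")` (a hypothetical path) unpacks every .zip or .tar.gz file in that folder and ignores everything else.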

Usage

Generate Preprocessed Data

python utils.py --data_type yelp --fine_grained
optional arguments:
--data_type              dataset type [default value is 'imdb'](choices:['imdb', 'newsgroups', 'reuters', 'webkb', 
                         'cade', 'dbpedia', 'agnews', 'yahoo', 'sogou', 'yelp', 'amazon'])
--fine_grained           use fine grained class or not, it only works for reuters, yelp and amazon [default value is False]

This step is optional and takes a long time. The preprocessed data has already been generated and uploaded to the cloud storage links above, so you can skip this step and go straight to training. The code downloads the preprocessed data automatically.
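If you do want to regenerate everything, you can repeat the command above for every dataset. Add --fine_grained only where the option list says it has an effect. This is an illustrative sketch: preprocess_commands and run_all are hypothetical helpers, and the code assumes it runs from the repo root where utils.py lives.

```python
import subprocess
import sys

DATA_TYPES = ["imdb", "newsgroups", "reuters", "webkb", "cade", "dbpedia",
              "agnews", "yahoo", "sogou", "yelp", "amazon"]
# --fine_grained only has an effect for these datasets
FINE_GRAINED_DATA_TYPES = {"reuters", "yelp", "amazon"}


def preprocess_commands(data_types=DATA_TYPES):
    """Build one utils.py invocation per dataset variant (argv lists, not shell strings)."""
    commands = []
    for data_type in data_types:
        if data_type not in DATA_TYPES:
            raise ValueError(f"unknown data_type: {data_type!r}")
        base = [sys.executable, "utils.py", "--data_type", data_type]
        commands.append(base)
        if data_type in FINE_GRAINED_DATA_TYPES:
            commands.append(base + ["--fine_grained"])
    return commands


def run_all(data_types=DATA_TYPES):
    # Run sequentially; check=True stops at the first failing dataset.
    for command in preprocess_commands(data_types):
        subprocess.run(command, check=True)
```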

Train Text Classification

visdom -logging_level WARNING & python main.py --data_type newsgroups --num_epochs 70
optional arguments:
--data_type              dataset type [default value is 'imdb'](choices:['imdb', 'newsgroups', 'reuters', 'webkb', 
                         'cade', 'dbpedia', 'agnews', 'yahoo', 'sogou', 'yelp', 'amazon'])
--fine_grained           use fine grained class or not, it only works for reuters, yelp and amazon [default value is False]
--text_length            the number of words about the text to load [default value is 5000]
--routing_type           routing type, it only works for capsule classifier [default value is 'k_means'](choices:['k_means', 'dynamic'])
--loss_type              loss type [default value is 'mf'](choices:['margin', 'focal', 'cross', 'mf', 'mc', 'fc', 'mfc'])
--embedding_type         embedding type [default value is 'cwc'](choices:['cwc', 'cc', 'normal'])
--classifier_type        classifier type [default value is 'capsule'](choices:['capsule', 'linear'])
--embedding_size         embedding size [default value is 64]
--num_codebook           codebook number, it only works for cwc and cc embedding [default value is 8]
--num_codeword           codeword number, it only works for cwc and cc embedding [default value is None]
--hidden_size            hidden size [default value is 128]
--in_length              in capsule length, it only works for capsule classifier [default value is 8]
--out_length             out capsule length, it only works for capsule classifier [default value is 16]
--num_iterations         routing iterations number, it only works for capsule classifier [default value is 3]
--num_repeat             gumbel softmax repeat number, it only works for cc embedding [default value is 10]
--drop_out               drop_out rate of GRU layer [default value is 0.5]
--batch_size             train batch size [default value is 32]
--num_epochs             train epochs number [default value is 10]
--num_steps              test steps number [default value is 100]
--pre_model              pre-trained model weight, it only works for routing_type experiment [default value is None]
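Several options above only take effect for particular classifier, embedding or dataset choices. The hypothetical helper below encodes those "it only works for ..." notes so you can spot flags that a given configuration would ignore. It is a sketch based on the option list, not code from main.py, and it does not cover --pre_model.

```python
# Option groups taken from the "it only works for ..." notes in the option list.
CAPSULE_ONLY = {"routing_type", "in_length", "out_length", "num_iterations"}
CODEBOOK_ONLY = {"num_codebook", "num_codeword"}  # cwc and cc embeddings
CC_ONLY = {"num_repeat"}
FINE_GRAINED_DATA_TYPES = {"reuters", "yelp", "amazon"}


def ignored_options(args):
    """Return the explicitly set options that the chosen configuration ignores.

    args maps option names (without the leading --) to values; options that
    are absent fall back to the documented defaults.
    """
    classifier = args.get("classifier_type", "capsule")
    embedding = args.get("embedding_type", "cwc")
    ignored = []
    if classifier != "capsule":
        ignored += sorted(CAPSULE_ONLY.intersection(args))
    if embedding == "normal":
        ignored += sorted(CODEBOOK_ONLY.intersection(args))
    if embedding != "cc":
        ignored += sorted(CC_ONLY.intersection(args))
    if args.get("fine_grained") and args.get("data_type", "imdb") not in FINE_GRAINED_DATA_TYPES:
        ignored.append("fine_grained")
    return ignored
```

For example, a linear classifier with the default cwc embedding ignores both routing iterations and num_repeat: `ignored_options({"classifier_type": "linear", "num_iterations": 5, "num_repeat": 30})` returns `["num_iterations", "num_repeat"]`.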

Visdom can then be accessed at 127.0.0.1:8097/env/$data_type in your browser, where $data_type is the dataset you are training on.
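The shell command above starts the Visdom server in the background and then trains. The Python sketch below does the same with subprocess and also builds the Visdom URL for the run. The helper names are hypothetical, and the code assumes the default Visdom port 8097 shown above and that it runs from the repo root.

```python
import subprocess
import sys


def visdom_url(data_type, host="127.0.0.1", port=8097):
    """URL of the Visdom environment that a training run on data_type logs to."""
    return f"http://{host}:{port}/env/{data_type}"


def start_training(data_type, num_epochs=10):
    """Start the Visdom server in the background, then train (blocking).

    Mirrors `visdom -logging_level WARNING & python main.py ...`. The server
    process is returned so the caller can stop it once the plots are no
    longer needed.
    """
    server = subprocess.Popen(["visdom", "-logging_level", "WARNING"])
    try:
        subprocess.run([sys.executable, "main.py", "--data_type", data_type,
                        "--num_epochs", str(num_epochs)], check=True)
    except BaseException:
        server.terminate()  # do not leave the server running if training fails
        raise
    return server
```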

Benchmarks

The Adam optimizer is used with learning rate scheduling. The models are trained for 10 epochs with a batch size of 32 on one NVIDIA Tesla V100 (32 GB) GPU.

The texts are preprocessed to keep only numbers and English words, with a maximum length of 5000 words.
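Below is a minimal sketch of the preprocessing rule stated above. It assumes that "only numbers and English words" means keeping ASCII letters and digits and treating every other character as a word boundary. The actual tokenization in utils.py may differ, for example in case handling or in how apostrophes are treated.

```python
import re

MAX_LENGTH = 5000  # maximum number of words kept per text
_NOT_NUMBER_OR_ENGLISH = re.compile(r"[^A-Za-z0-9]+")


def clean_text(text, max_length=MAX_LENGTH):
    """Keep only ASCII letters and digits, split into words, and truncate.

    Everything else (punctuation, non-ASCII characters) becomes a word
    boundary. Case is left unchanged because the README does not say
    whether texts are lower-cased.
    """
    words = _NOT_NUMBER_OR_ENGLISH.sub(" ", text).split()
    return words[:max_length]
```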

Here are the dataset details:

| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Num. of Train Texts | 120,000 | 560,000 | 1,400,000 | 450,000 | 560,000 | 650,000 | 3,600,000 | 3,000,000 |
| Num. of Test Texts | 7,600 | 70,000 | 60,000 | 60,000 | 38,000 | 50,000 | 400,000 | 650,000 |
| Vocabulary Size | 62,535 | 548,338 | 771,820 | 106,385 | 200,790 | 216,985 | 931,271 | 835,818 |
| Num. of Classes | 4 | 14 | 10 | 5 | 2 | 5 | 2 | 5 |

Here are the model parameter counts. Each model is named embedding_type-classifier_type:

| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Normal-Linear | 4,448,192 | 35,540,864 | 49,843,200 | 7,254,720 | 13,296,256 | 14,333,120 | 60,047,040 | 53,938,432 |
| CC-Linear | 2,449,120 | 26,770,528 | 37,497,152 | 4,704,040 | 8,479,856 | 9,128,040 | 45,149,776 | 40,568,416 |
| CWC-Linear | 2,449,120 | 26,770,528 | 37,497,152 | 4,704,040 | 8,479,856 | 9,128,040 | 45,149,776 | 40,568,416 |
| Normal-Capsule | 4,455,872 | 35,567,744 | 49,862,400 | 7,264,320 | 13,300,096 | 14,342,720 | 60,050,880 | 53,948,032 |
| CC-Capsule | 2,456,800 | 26,797,408 | 37,516,352 | 4,713,640 | 8,483,696 | 9,137,640 | 45,153,616 | 40,578,016 |
| CWC-Capsule | 2,456,800 | 26,797,408 | 37,516,352 | 4,713,640 | 8,483,696 | 9,137,640 | 45,153,616 | 40,578,016 |
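Subtract the Linear rows from the Capsule rows in the table above and divide by the number of classes, and you get the same value for every dataset. A short script makes this easy to verify. The numbers are copied from the tables, and the script is only an arithmetic check.

```python
DATASETS = ["agnews", "dbpedia", "yahoo", "sogou", "yelp",
            "yelp fine grained", "amazon", "amazon fine grained"]
NUM_CLASSES = [4, 14, 10, 5, 2, 5, 2, 5]
# Parameter counts copied from the Normal-Linear and Normal-Capsule rows.
NORMAL_LINEAR = [4_448_192, 35_540_864, 49_843_200, 7_254_720,
                 13_296_256, 14_333_120, 60_047_040, 53_938_432]
NORMAL_CAPSULE = [4_455_872, 35_567_744, 49_862_400, 7_264_320,
                  13_300_096, 14_342_720, 60_050_880, 53_948_032]


def capsule_overhead_per_class():
    """Extra parameters of the capsule classifier over the linear one, per class."""
    overhead = {}
    for name, classes, linear, capsule in zip(DATASETS, NUM_CLASSES,
                                              NORMAL_LINEAR, NORMAL_CAPSULE):
        per_class, remainder = divmod(capsule - linear, classes)
        if remainder:
            raise ValueError(f"{name}: difference is not a multiple of the class count")
        overhead[name] = per_class
    return overhead


print(capsule_overhead_per_class())
```

The printed dictionary maps every dataset to 1920. In these configurations, then, the capsule classifier adds 1,920 parameters per class over the linear classifier, regardless of vocabulary size. The CC and CWC rows show the same gaps.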

Here are the loss function results. We use the AGNews dataset and the Normal-Linear model to compare loss functions:

| Loss Function | margin | focal | cross | margin+focal | margin+cross | focal+cross | margin+focal+cross |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Accuracy | 92.37% | 92.13% | 92.05% | 92.64% | 91.95% | 92.09% | 92.38% |
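The mf loss type combines the margin and focal losses. The sketch below shows one common way to write each term in PyTorch. Several details are assumptions and are not taken from the repo:

- the margin-loss constants (0.9, 0.1, 0.5);
- the focal gamma of 2;
- treating the same per-class scores as both margin-loss lengths and focal-loss logits;
- summing the two terms with equal weight.

```python
import torch
import torch.nn.functional as F


def margin_loss(scores, target, m_pos=0.9, m_neg=0.1, lam=0.5):
    """Capsule-style margin loss on per-class scores expected to lie in [0, 1]."""
    one_hot = F.one_hot(target, num_classes=scores.size(1)).to(scores.dtype)
    present = one_hot * F.relu(m_pos - scores).pow(2)
    absent = lam * (1.0 - one_hot) * F.relu(scores - m_neg).pow(2)
    return (present + absent).sum(dim=1).mean()


def focal_loss(scores, target, gamma=2.0):
    """Cross entropy down-weighted by (1 - p_t) ** gamma for well-classified examples."""
    log_p = F.log_softmax(scores, dim=1)
    log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)
    return (-(1.0 - log_pt.exp()).pow(gamma) * log_pt).mean()


def margin_focal_loss(scores, target):
    # Assumption: the two terms are simply summed with equal weight.
    return margin_loss(scores, target) + focal_loss(scores, target)
```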

Here are the accuracy results. We use margin+focal as the loss function and 3 routing iterations for the capsule models. For CC embeddings, the number appended to the model name is num_repeat:

| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Normal-Linear | 92.64% | 98.84% | 74.13% | 97.37% | 96.69% | 66.23% | 95.09% | 60.78% |
| CC-Linear-10 | 73.11% | 92.66% | 48.01% | 93.50% | 87.81% | 50.33% | 83.20% | 45.77% |
| CC-Linear-30 | 81.05% | 95.29% | 53.50% | 94.65% | 91.33% | 55.22% | 87.37% | 50.00% |
| CC-Linear-50 | 83.13% | 96.06% | 57.87% | 95.20% | 92.37% | 56.66% | 89.04% | 51.30% |
| CWC-Linear | 91.93% | 98.83% | 73.58% | 97.37% | 96.35% | 65.11% | 94.90% | 60.29% |
| Normal-Capsule | 92.18% | 98.86% | 74.12% | 97.52% | 96.56% | 66.23% | 95.18% | 61.36% |
| CC-Capsule-10 | 73.53% | 93.04% | 50.52% | 94.44% | 87.98% | 54.14% | 83.64% | 47.44% |
| CC-Capsule-30 | 81.71% | 95.72% | 60.48% | 95.96% | 91.90% | 58.27% | 87.88% | 51.63% |
| CC-Capsule-50 | 84.05% | 96.27% | 60.31% | 96.00% | 92.82% | 59.48% | 89.07% | 52.06% |
| CWC-Capsule | 92.12% | 98.81% | 73.78% | 97.42% | 96.28% | 65.38% | 94.98% | 60.94% |

Here are the model parameter counts for CWC-Capsule with different codeword settings. Each model name lists the num_codeword used for each dataset, one digit per dataset in column order:

| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 57766677 | 2,957,592 | 31,184,624 | 43,691,424 | 5,565,232 | 10,090,528 | 10,874,032 | 52,604,296 | 47,265,072 |
| 68877788 | 3,458,384 | 35,571,840 | 49,866,496 | 6,416,824 | 11,697,360 | 12,610,424 | 60,054,976 | 53,952,128 |
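With num_codebook codebooks of K codewords each, there are K ** num_codebook distinct codes, so K must be large enough for those codes to cover the vocabulary. The sketch below decodes a config name into one num_codeword per dataset, assuming the digits follow the column order. It then compares each value with the smallest K that covers the vocabulary sizes listed earlier, assuming the default num_codebook of 8. This README does not say whether the repo derives its default num_codeword this way.

```python
DATASETS = ["agnews", "dbpedia", "yahoo", "sogou", "yelp",
            "yelp fine grained", "amazon", "amazon fine grained"]
VOCAB_SIZES = [62_535, 548_338, 771_820, 106_385, 200_790, 216_985, 931_271, 835_818]
NUM_CODEBOOK = 8  # documented default for --num_codebook


def decode_config(config):
    """Split a name such as '57766677' into one num_codeword per dataset (column order)."""
    if len(config) != len(DATASETS) or not config.isdigit():
        raise ValueError(f"expected {len(DATASETS)} digits, got {config!r}")
    return dict(zip(DATASETS, map(int, config)))


def min_codewords(vocab_size, num_codebook=NUM_CODEBOOK):
    """Smallest K such that K ** num_codebook distinct codes cover the vocabulary."""
    k = 1
    while k ** num_codebook < vocab_size:
        k += 1
    return k


for config in ("57766677", "68877788"):
    codewords = decode_config(config)
    offsets = [codewords[name] - min_codewords(size)
               for name, size in zip(DATASETS, VOCAB_SIZES)]
    print(config, offsets)
```

This prints an offset of 1 for every dataset under 57766677 and 2 under 68877788. Each configuration therefore sits one or two codewords above that minimum.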

Here are the accuracy results:

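| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 57766677 | 92.54% | 98.85% | 73.96% | 97.41% | 96.38% | 65.86% | 94.98% | 60.98% |
| 68877788 | 92.05% | 98.82% | 73.93% | 97.52% | 96.44% | 65.63% | 95.05% | 61.02% |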

Here are the accuracy results for the 57766677 configuration with different numbers of routing iterations. Each model is named by its num_iterations:

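| Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 92.28% | 98.82% | 73.93% | 97.25% | 96.58% | 65.60% | 95.00% | 61.08% |
| 3 | 92.54% | 98.85% | 73.96% | 97.41% | 96.38% | 65.86% | 94.98% | 60.98% |
| 5 | 92.21% | 98.88% | 73.85% | 97.38% | 96.38% | 65.36% | 95.05% | 61.23% |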
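The routing_type option selects between k_means and dynamic routing, and the table above varies num_iterations. Below is a generic soft k-means-style routing sketch in PyTorch, for intuition only. It is not claimed to be the paper's exact algorithm. The input layout (batch, out_caps, in_caps, out_length), the cosine similarity and the softmax assignment are all assumptions.

```python
import torch
import torch.nn.functional as F


def kmeans_routing(u_hat, num_iterations=3):
    """Soft k-means-style routing over capsule predictions.

    u_hat: (batch, out_caps, in_caps, out_length) prediction vectors, where
    u_hat[:, j, i] is input capsule i's prediction for output capsule j.
    Returns (batch, out_caps, out_length) output capsules (cluster centres).
    """
    u_unit = F.normalize(u_hat, dim=-1)
    # Initialise each centre as the plain mean of the predictions for it.
    v = u_hat.mean(dim=2)
    for _ in range(num_iterations):
        # Cosine similarity of every prediction to its output capsule's centre.
        sim = (u_unit * F.normalize(v, dim=-1).unsqueeze(2)).sum(dim=-1)  # (batch, out_caps, in_caps)
        # Soft assignment: each input capsule spreads itself over the output capsules.
        c = F.softmax(sim, dim=1)
        # Update each centre as the assignment-weighted mean of its predictions.
        v = (c.unsqueeze(-1) * u_hat).sum(dim=2) / c.sum(dim=2, keepdim=True)
    return v
```

The returned capsules are not squashed, so their lengths are not bounded to [0, 1]. A classifier that feeds them to a margin loss would need a squashing or normalisation step, which is again an assumption about how this would be wired up.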

Results

The train/test loss, accuracy and confusion matrix are shown with Visdom. The pretrained models and more results can be found in BaiduYun (access code: xer4) and Google Drive.

[Result figures for agnews, dbpedia, yahoo, sogou, yelp, yelp fine grained, amazon and amazon fine grained]
