A PyTorch implementation of Compositional Coding Capsule Network, based on the PRL 2022 paper Compositional Coding Capsule Network with K-Means Routing for Text Classification.
- Anaconda
- PyTorch
```
conda install pytorch torchvision -c pytorch
```
- PyTorchNet
```
pip install git+https://github.com/pytorch/tnt.git@master
```
- PyTorch-NLP
```
pip install pytorch-nlp
```
- capsule-layer
```
pip install git+https://github.com/leftthomas/CapsuleLayer.git@master
```
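As a quick sanity check that the environment is complete, the following imports should succeed (the module names below are the packages' conventional import names, which is an assumption on our part):

```python
# Sanity check: all four dependencies import under their usual module names.
import torch
import torchnet       # installed by the PyTorchNet (tnt) command above
import torchnlp       # installed by pytorch-nlp
import capsule_layer  # installed by the capsule-layer command above

print(torch.__version__, "CUDA available:", torch.cuda.is_available())
```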
The original AGNews, AmazonReview, DBPedia, YahooAnswers, SogouNews and YelpReview datasets come from here.
The original Newsgroups, Reuters, Cade and WebKB datasets can be found here.
The original IMDB dataset is downloaded by PyTorch-NLP automatically.
We have uploaded all the original datasets to BaiduYun (access code: kddr) and GoogleDrive. The preprocessed datasets have been uploaded to BaiduYun (access code: 2kyd) and GoogleDrive.
You don't need to download the datasets yourself; the code downloads them automatically. If you encounter network issues, download the datasets from the cloud storage links above and extract them into the `data` directory.
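If you do fetch the archives manually, a snippet along these lines unpacks one into place (the archive filename `agnews.tar.gz` is only a placeholder; substitute whatever you downloaded):

```python
import pathlib
import tarfile

# Create the data/ directory the code expects, then extract the archive into it.
data_dir = pathlib.Path("data")
data_dir.mkdir(exist_ok=True)

with tarfile.open("agnews.tar.gz") as archive:
    archive.extractall(data_dir)
```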
```
python utils.py --data_type yelp --fine_grained
optional arguments:
--data_type              dataset type [default value is 'imdb'](choices:['imdb', 'newsgroups', 'reuters', 'webkb', 'cade', 'dbpedia', 'agnews', 'yahoo', 'sogou', 'yelp', 'amazon'])
--fine_grained           use fine grained class or not, it only works for reuters, yelp and amazon [default value is False]
```
This step is optional and takes a long time to execute, so the preprocessed data has already been generated and uploaded to the cloud storage links above. You can skip this step and go straight to training; the code will download the preprocessed data automatically.
```
visdom -logging_level WARNING & python main.py --data_type newsgroups --num_epochs 70
optional arguments:
--data_type              dataset type [default value is 'imdb'](choices:['imdb', 'newsgroups', 'reuters', 'webkb', 'cade', 'dbpedia', 'agnews', 'yahoo', 'sogou', 'yelp', 'amazon'])
--fine_grained           use fine grained class or not, it only works for reuters, yelp and amazon [default value is False]
--text_length            the number of words about the text to load [default value is 5000]
--routing_type           routing type, it only works for capsule classifier [default value is 'k_means'](choices:['k_means', 'dynamic'])
--loss_type              loss type [default value is 'mf'](choices:['margin', 'focal', 'cross', 'mf', 'mc', 'fc', 'mfc'])
--embedding_type         embedding type [default value is 'cwc'](choices:['cwc', 'cc', 'normal'])
--classifier_type        classifier type [default value is 'capsule'](choices:['capsule', 'linear'])
--embedding_size         embedding size [default value is 64]
--num_codebook           codebook number, it only works for cwc and cc embedding [default value is 8]
--num_codeword           codeword number, it only works for cwc and cc embedding [default value is None]
--hidden_size            hidden size [default value is 128]
--in_length              in capsule length, it only works for capsule classifier [default value is 8]
--out_length             out capsule length, it only works for capsule classifier [default value is 16]
--num_iterations         routing iterations number, it only works for capsule classifier [default value is 3]
--num_repeat             gumbel softmax repeat number, it only works for cc embedding [default value is 10]
--drop_out               drop_out rate of GRU layer [default value is 0.5]
--batch_size             train batch size [default value is 32]
--num_epochs             train epochs number [default value is 10]
--num_steps              test steps number [default value is 100]
--pre_model              pre-trained model weight, it only works for routing_type experiment [default value is None]
```
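To make the `cwc`/`cc` embedding options above concrete, here is a minimal sketch of the compositional-coding idea: every word picks one codeword from each of `num_codebook` shared codebooks, and its embedding is the sum of the selected codewords. This is an illustration under simplifying assumptions (fixed random codes instead of the learned, Gumbel-Softmax-based codes), not the repo's actual implementation:

```python
import torch
import torch.nn as nn

class CompositionalEmbedding(nn.Module):
    """Each word selects one codeword from each of `num_codebook` codebooks;
    its embedding is the sum of the selected codewords. The small per-word
    code matrix replaces a full vocab-by-dim embedding table."""

    def __init__(self, num_embeddings, embedding_dim, num_codebook=8, num_codeword=32):
        super().__init__()
        # Shared codebooks: (num_codebook, num_codeword, embedding_dim).
        self.codebooks = nn.Parameter(
            torch.randn(num_codebook, num_codeword, embedding_dim) * 0.1
        )
        # Fixed random codes for brevity; the paper learns them
        # (via Gumbel-Softmax, cf. --num_repeat).
        self.register_buffer(
            "codes", torch.randint(num_codeword, (num_embeddings, num_codebook))
        )

    def forward(self, x):  # x: (batch, seq_len) word indices
        codes = self.codes[x]  # (batch, seq_len, num_codebook)
        # Gather the chosen codeword from each codebook and sum them.
        out = 0
        for b in range(self.codebooks.size(0)):
            out = out + self.codebooks[b, codes[..., b]]
        return out  # (batch, seq_len, embedding_dim)

emb = CompositionalEmbedding(num_embeddings=10000, embedding_dim=64)
print(emb(torch.randint(10000, (2, 7))).shape)  # torch.Size([2, 7, 64])
```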
Visdom can now be accessed by navigating to `127.0.0.1:8097/env/$data_type` in your browser, where `$data_type` is the dataset type you are training on.
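Metrics can also be pushed to the same server from Python via the standard visdom client API, for example:

```python
import numpy as np
import visdom

# Connect to the visdom server started above (default 127.0.0.1:8097);
# env selects the dashboard shown at /env/$data_type.
vis = visdom.Visdom(env="newsgroups")
vis.line(X=np.array([1, 2, 3]), Y=np.array([0.9, 0.6, 0.4]),
         opts={"title": "train loss"})
```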
The Adam optimizer is used with learning rate scheduling. The models are trained for 10 epochs with a batch size of 32 on one NVIDIA Tesla V100 (32GB) GPU.
The texts are preprocessed to keep only numbers and English words, with a maximum length of 5,000 words.
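A minimal sketch of that preprocessing rule (the real pipeline lives in `utils.py` and may differ in details):

```python
import re

def tokenize(text, max_length=5000):
    # Keep only runs of English letters or digits, lower-cased,
    # and truncate to the first `max_length` tokens.
    tokens = re.findall(r"[a-z]+|[0-9]+", text.lower())
    return tokens[:max_length]

print(tokenize("It's 100% great!!!"))  # ['it', 's', '100', 'great']
```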
Here are the dataset details:
Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
---|---|---|---|---|---|---|---|---|
Num. of Train Texts | 120,000 | 560,000 | 1,400,000 | 450,000 | 560,000 | 650,000 | 3,600,000 | 3,000,000 |
Num. of Test Texts | 7,600 | 70,000 | 60,000 | 60,000 | 38,000 | 50,000 | 400,000 | 650,000 |
Num. of Vocabulary | 62,535 | 548,338 | 771,820 | 106,385 | 200,790 | 216,985 | 931,271 | 835,818 |
Num. of Classes | 4 | 14 | 10 | 5 | 2 | 5 | 2 | 5 |
Here are the model parameter details; the model names follow the pattern `embedding_type-classifier_type`:
Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
---|---|---|---|---|---|---|---|---|
Normal-Linear | 4,448,192 | 35,540,864 | 49,843,200 | 7,254,720 | 13,296,256 | 14,333,120 | 60,047,040 | 53,938,432 |
CC-Linear | 2,449,120 | 26,770,528 | 37,497,152 | 4,704,040 | 8,479,856 | 9,128,040 | 45,149,776 | 40,568,416 |
CWC-Linear | 2,449,120 | 26,770,528 | 37,497,152 | 4,704,040 | 8,479,856 | 9,128,040 | 45,149,776 | 40,568,416 |
Normal-Capsule | 4,455,872 | 35,567,744 | 49,862,400 | 7,264,320 | 13,300,096 | 14,342,720 | 60,050,880 | 53,948,032 |
CC-Capsule | 2,456,800 | 26,797,408 | 37,516,352 | 4,713,640 | 8,483,696 | 9,137,640 | 45,153,616 | 40,578,016 |
CWC-Capsule | 2,456,800 | 26,797,408 | 37,516,352 | 4,713,640 | 8,483,696 | 9,137,640 | 45,153,616 | 40,578,016 |
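Counts like the ones above can be reproduced for any of the models with a standard helper:

```python
import torch.nn as nn

def count_parameters(model):
    # Total number of trainable parameters in a PyTorch module.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(nn.Linear(64, 4)))  # 64*4 weights + 4 biases = 260
```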
Here are the loss function details; we use the AGNews dataset and the Normal-Linear model to test the different loss functions:
Loss Function | margin | focal | cross | margin+focal | margin+cross | focal+cross | margin+focal+cross |
---|---|---|---|---|---|---|---|
Accuracy | 92.37% | 92.13% | 92.05% | 92.64% | 91.95% | 92.09% | 92.38% |
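To make the combined losses concrete, here is a minimal sketch of the `mf` (margin+focal) combination, using the CapsNet margin loss and the standard focal loss (γ = 2); the repo's exact formulation may differ:

```python
import torch
import torch.nn.functional as F

def margin_loss(lengths, labels, m_pos=0.9, m_neg=0.1, lam=0.5):
    # CapsNet margin loss over class scores (batch, num_classes) in [0, 1].
    one_hot = F.one_hot(labels, lengths.size(1)).float()
    pos = one_hot * torch.clamp(m_pos - lengths, min=0) ** 2
    neg = lam * (1 - one_hot) * torch.clamp(lengths - m_neg, min=0) ** 2
    return (pos + neg).sum(dim=1).mean()

def focal_loss(logits, labels, gamma=2.0):
    # Standard focal loss: down-weights well-classified examples.
    log_p = F.log_softmax(logits, dim=1).gather(1, labels.unsqueeze(1)).squeeze(1)
    return (-(1 - log_p.exp()) ** gamma * log_p).mean()

# margin+focal ("mf") simply sums the two terms.
scores = torch.rand(8, 4)        # class scores / capsule lengths
labels = torch.randint(4, (8,))
loss = margin_loss(scores, labels) + focal_loss(scores, labels)
```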
Here are the accuracy details. We use margin+focal as the loss function; for the capsule models, 3 routing iterations are used. If the embedding_type is CC, the model name is suffixed with the num_repeat value:
Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
---|---|---|---|---|---|---|---|---|
Normal-Linear | 92.64% | 98.84% | 74.13% | 97.37% | 96.69% | 66.23% | 95.09% | 60.78% |
CC-Linear-10 | 73.11% | 92.66% | 48.01% | 93.50% | 87.81% | 50.33% | 83.20% | 45.77% |
CC-Linear-30 | 81.05% | 95.29% | 53.50% | 94.65% | 91.33% | 55.22% | 87.37% | 50.00% |
CC-Linear-50 | 83.13% | 96.06% | 57.87% | 95.20% | 92.37% | 56.66% | 89.04% | 51.30% |
CWC-Linear | 91.93% | 98.83% | 73.58% | 97.37% | 96.35% | 65.11% | 94.90% | 60.29% |
Normal-Capsule | 92.18% | 98.86% | 74.12% | 97.52% | 96.56% | 66.23% | 95.18% | 61.36% |
CC-Capsule-10 | 73.53% | 93.04% | 50.52% | 94.44% | 87.98% | 54.14% | 83.64% | 47.44% |
CC-Capsule-30 | 81.71% | 95.72% | 60.48% | 95.96% | 91.90% | 58.27% | 87.88% | 51.63% |
CC-Capsule-50 | 84.05% | 96.27% | 60.31% | 96.00% | 92.82% | 59.48% | 89.07% | 52.06% |
CWC-Capsule | 92.12% | 98.81% | 73.78% | 97.42% | 96.28% | 65.38% | 94.98% | 60.94% |
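The -10/-30/-50 suffixes above are `num_repeat` values for the `cc` embedding, which selects codewords with Gumbel-Softmax. A minimal sketch of such a selection using PyTorch's built-in (the shapes here are hypothetical):

```python
import torch
import torch.nn.functional as F

# Logits over num_codeword choices for one codebook: 5 words, 32 codewords.
logits = torch.randn(5, 32)

# Differentiable, near-one-hot selection; hard=True uses the
# straight-through estimator so each row is exactly one-hot.
codes = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(codes.shape, codes.sum(dim=-1))  # torch.Size([5, 32]), all ones
```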
Here are the model parameter details. We use CWC-Capsule as the model; each model name concatenates the num_codeword value used for each dataset (one digit per dataset column):
Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
---|---|---|---|---|---|---|---|---|
57766677 | 2,957,592 | 31,184,624 | 43,691,424 | 5,565,232 | 10,090,528 | 10,874,032 | 52,604,296 | 47,265,072 |
68877788 | 3,458,384 | 35,571,840 | 49,866,496 | 6,416,824 | 11,697,360 | 12,610,424 | 60,054,976 | 53,952,128 |
Here are the accuracy details:
Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
---|---|---|---|---|---|---|---|---|
57766677 | 92.54% | 98.85% | 73.96% | 97.41% | 96.38% | 65.86% | 94.98% | 60.98% |
68877788 | 92.05% | 98.82% | 73.93% | 97.52% | 96.44% | 65.63% | 95.05% | 61.02% |
Here are the accuracy details. We use the 57766677 codeword config; the model names denote num_iterations:
Dataset | agnews | dbpedia | yahoo | sogou | yelp | yelp fine grained | amazon | amazon fine grained |
---|---|---|---|---|---|---|---|---|
1 | 92.28% | 98.82% | 73.93% | 97.25% | 96.58% | 65.60% | 95.00% | 61.08% |
3 | 92.54% | 98.85% | 73.96% | 97.41% | 96.38% | 65.86% | 94.98% | 60.98% |
5 | 92.21% | 98.88% | 73.85% | 97.38% | 96.38% | 65.36% | 95.05% | 61.23% |
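For intuition about `num_iterations`, here is a minimal sketch of k-means routing: the output capsules act as centroids of the prediction vectors, and the soft assignments are refined for the given number of iterations. This illustrates the idea from the paper, not the exact implementation in capsule-layer:

```python
import torch
import torch.nn.functional as F

def k_means_routing(u_hat, num_iterations=3):
    # u_hat: (batch, in_caps, out_caps, out_len) prediction vectors.
    v = u_hat.mean(dim=1)  # initial centroids: (batch, out_caps, out_len)
    for _ in range(num_iterations):
        # Dot-product similarity of every prediction vector to each centroid.
        sim = (u_hat * v.unsqueeze(1)).sum(dim=-1)   # (batch, in_caps, out_caps)
        c = F.softmax(sim, dim=-1).unsqueeze(-1)     # soft assignment weights
        v = (c * u_hat).sum(dim=1) / c.sum(dim=1)    # weighted centroid update
    return v  # (batch, out_caps, out_len)

v = k_means_routing(torch.randn(2, 32, 10, 16))
print(v.shape)  # torch.Size([2, 10, 16])
```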
The train/test loss, accuracy and confusion matrix are shown with Visdom. The pretrained models and more results can be found in BaiduYun (access code: xer4) and GoogleDrive.
Result screenshots are provided for each dataset: agnews, dbpedia, yahoo, sogou, yelp, yelp fine grained, amazon, and amazon fine grained.