A Capsule Network (CapsNet) implementation in TensorFlow, based on Geoffrey Hinton's paper Dynamic Routing Between Capsules.
Fig1. Capsule Network architecture from Hinton's paper
I started learning about CapsNet by reading the paper and watching Hinton's talks (like This one). While they are fascinating, they give very limited detail (most likely because of the paper's page limit and the talks' time limits). So, to get the detailed picture, I watched many videos and read many blog posts. To fully master the theory behind Capsule Networks, I suggest the following sources (in addition to Hinton's paper, of course):
- Max Pechyonkin's blog series
- Aurélien Géron's videos on Capsule Networks – Tutorial and How to implement CapsNets using TensorFlow
This code is partly inspired by Liao's implementation, CapsNet-Tensorflow, with changes that add some features and make the code more efficient.
The main changes include:
- The Capslayer class is removed, as I found it unnecessary at this level. This makes the whole code shorter and better structured.
- Hard-coded values (such as the number of capsules, hidden units, etc.) are all extracted and are accessible through the config.py file, so changing the network structure is much more convenient.
- Summary writers are modified. Liao's code writes the loss and accuracy only for the final batch after a given number of steps. Here, the average loss and accuracy over those steps are written instead, which is the quantity we actually want to monitor.
- The masking procedure is modified, so it is much easier to see how the masking differs between training and testing.
- A Visualize mode is added, which plots the reconstructed sample images and helps visualize the network's prediction performance.
- Saving and loading of trained models are improved.
- Datasets (MNIST and Fashion-MNIST) are downloaded automatically.
- The code displays real-time results in the terminal (or command line).
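To illustrate the masking difference between training and testing mentioned in the list above, here is a minimal NumPy sketch. It is not the repository's TensorFlow code: the function name `mask_digit_caps` and the array shapes are assumptions made for illustration. It follows the idea from the paper, where the true label selects the capsule during training and the longest capsule is selected at test time.

```python
import numpy as np

def mask_digit_caps(digit_caps, labels=None):
    """Zero out every output capsule except one per sample.

    digit_caps: array of shape (batch, n_classes, caps_dim).
    labels: integer class ids of shape (batch,), passed during training.
            When None (test time), the longest capsule is kept instead.
    Returns a flattened array of shape (batch, n_classes * caps_dim),
    which is the kind of input a reconstruction decoder would take.
    """
    if labels is None:
        # Test time: the predicted class is the capsule with the largest norm.
        lengths = np.linalg.norm(digit_caps, axis=-1)  # (batch, n_classes)
        labels = np.argmax(lengths, axis=-1)
    n_classes = digit_caps.shape[1]
    one_hot = np.eye(n_classes)[labels]                # (batch, n_classes)
    masked = digit_caps * one_hot[:, :, None]          # broadcast over caps_dim
    return masked.reshape(len(digit_caps), -1)
```

Calling `mask_digit_caps(caps, labels=y)` corresponds to the training case and `mask_digit_caps(caps)` to the test case.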
All in all, the main features of the code are:
- The current version supports MNIST and Fashion-MNIST datasets.
- The code runs in Train, Test, and Visualize modes (explained below).
- The best validation and test accuracies for MNIST and Fashion-MNIST are as follows (see details in the Results section):
Data set | Validation accuracy | Validation loss | Test accuracy | Test loss |
---|---|---|---|---|
MNIST | 99.44% | 0.0065 | 99.34% | 0.0066 |
Fashion-MNIST | 91.17% | 0.069 | 90.57% | 0.071 |
- Python (preferably 2.7; also works fine with Python 3)
- NumPy
- Tensorflow>=1.3
- Matplotlib (for saving images)
The code downloads the MNIST and Fashion-MNIST datasets automatically if needed. MNIST is the default dataset.
Note: The default batch size is 128 and the default number of epochs is 50. You may need to modify the config.py file or use command-line parameters to suit your case, e.g. to set the batch size to 64: python main.py --batch_size=64
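As a rough illustration of how defaults and command-line overrides can fit together, the sketch below uses Python's standard `argparse` module. This is an assumption made for illustration only: the actual config.py may define its flags differently (for example with TensorFlow's own flags utility), and the flag name `--epoch` is a guess. The other flag names and defaults are the ones mentioned in this README.

```python
import argparse

def str2bool(value):
    """Parse values such as 'True' or 'false' given on the command line."""
    return str(value).lower() in ("true", "1", "yes")

def build_parser():
    # Hypothetical stand-in for config.py; defaults follow this README.
    parser = argparse.ArgumentParser(description="CapsNet options (illustrative)")
    parser.add_argument("--mode", default="train",
                        choices=["train", "test", "visualize", "adv_attack"])
    parser.add_argument("--dataset", default="mnist",
                        choices=["mnist", "fashion-mnist"])
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--epoch", type=int, default=50)  # flag name assumed
    parser.add_argument("--restore_training", type=str2bool, default=False)
    parser.add_argument("--n_samples", type=int, default=5)
    return parser

# Equivalent to: python main.py --dataset=fashion-mnist --batch_size=64
args = build_parser().parse_args(["--dataset=fashion-mnist", "--batch_size=64"])
print(args.mode, args.dataset, args.batch_size)
```

Any option not given on the command line keeps its default, so only the values that differ from the defaults need to be passed.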
Training the model displays the training results (accuracy and loss) and saves them to a .csv file every given number of steps (100 steps by default); validation results are displayed and saved after each epoch.
- For training on MNIST data:
python main.py
- For loading the model and continuing training:
python main.py --restore_training=True
- For training on Fashion-MNIST data:
python main.py --dataset=fashion-mnist
- For training with a different batch size:
python main.py --batch_size=100
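The periodic logging described for training (averaged values every 100 steps by default) can be sketched in plain Python as follows. The names `RunningMean` and `log_averages` are hypothetical and not taken from the repository; the sketch only shows the idea of averaging over all batches since the last report rather than reporting the final batch alone.

```python
class RunningMean(object):
    """Accumulates per-batch values and reports their mean on demand."""

    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, value):
        self.total += float(value)
        self.count += 1

    def pop(self):
        """Return the mean of the values seen so far, then reset."""
        mean = self.total / max(self.count, 1)
        self.total, self.count = 0.0, 0
        return mean


def log_averages(batch_losses, display_step=100):
    """Return one (step, mean_loss) pair for every `display_step` batches."""
    meter, rows = RunningMean(), []
    for step, loss in enumerate(batch_losses, start=1):
        meter.update(loss)
        if step % display_step == 0:
            rows.append((step, meter.pop()))
    return rows
```

Each returned row could then be written to the .csv file or to a summary writer; the same pattern applies to accuracy.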
- For testing on MNIST data:
python main.py --mode=test
- For testing on Fashion-MNIST data:
python main.py --dataset=fashion-mnist --mode=test
This mode runs the trained model on a number of samples, gets the predictions, and visualizes the original and reconstructed images (for 5 samples by default). The images are also saved automatically in the /results/ folder.
- For MNIST data on 10 images:
python main.py --mode=visualize --n_samples=10
- For Fashion-MNIST data:
python main.py --dataset=fashion-mnist --mode=visualize
This mode checks the vulnerability of the capsule network to adversarial examples: inputs that have been slightly changed by an attacker so as to trick a neural network classifier into making the wrong classification. Currently, only the untargeted FGSM (Fast Gradient Sign Method) and its iterative counterpart (commonly called the Basic Iterative Method, or BIM) are implemented. To run it on the trained model, use:
- FGSM mode:
python main.py --mode=adv_attack
- BIM mode:
python main.py --mode=adv_attack --max_iter=3
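For readers unfamiliar with the two attacks, here is a minimal NumPy sketch of untargeted FGSM and BIM. It is not the repository's implementation: `grad_fn` is a hypothetical callable returning the gradient of the loss with respect to the input, the toy logistic model stands in for the capsule network, and splitting the budget as `epsilon / max_iter` per step is one common choice rather than something stated in this README.

```python
import numpy as np

def fgsm(x, grad_fn, epsilon=0.1):
    """Untargeted FGSM: one step of size epsilon along the sign of the loss gradient."""
    return np.clip(x + epsilon * np.sign(grad_fn(x)), 0.0, 1.0)

def bim(x, grad_fn, epsilon=0.1, max_iter=3):
    """BIM: repeated small FGSM steps, kept inside an epsilon-ball around x."""
    alpha = epsilon / max_iter           # per-step size (an assumed choice)
    x_adv = x.copy()
    for _ in range(max_iter):
        x_adv = x_adv + alpha * np.sign(grad_fn(x_adv))
        x_adv = np.clip(x_adv, x - epsilon, x + epsilon)  # stay near the original
        x_adv = np.clip(x_adv, 0.0, 1.0)                  # stay a valid image
    return x_adv

# Toy stand-in model: logistic regression with fixed weights w and true label y = +1.
w = np.array([1.0, -2.0, 0.5])

def toy_grad(x, y=1.0):
    """Gradient of the logistic loss log(1 + exp(-y * w.x)) with respect to x."""
    margin = y * np.dot(w, x)
    return -y * w / (1.0 + np.exp(margin))

x = np.array([0.5, 0.5, 0.5])
x_fgsm = fgsm(x, toy_grad, epsilon=0.1)
x_bim = bim(x, toy_grad, epsilon=0.1, max_iter=3)
```

Both perturbed inputs lower the score `w.x` of the true class while changing each pixel by at most `epsilon`.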
Training, validation, and test results are saved separately for each dataset in .csv format. By default, they are saved in the /results/ directory.
To view the results and summaries in TensorBoard, open a terminal in the main folder and type tensorboard --logdir=results/mnist for MNIST or tensorboard --logdir=results/fashion-mnist for Fashion-MNIST, then open the generated link in your browser. It plots the average accuracy and total loss over training batches (over 100 batches by default), as well as the margin loss and reconstruction loss separately. They are accessible through the Scalars tab in the top menu.
Fig2. Accuracy and loss curves in Tensorboard
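For reference, the margin loss plotted in these curves can be sketched in NumPy as below. This follows the definition in the Dynamic Routing Between Capsules paper, with m+ = 0.9, m- = 0.1 and a down-weighting factor of 0.5 for absent classes. The function name and array shapes are assumptions, and the repository's TensorFlow code may differ in details such as the reduction over the batch.

```python
import numpy as np

def margin_loss(lengths, labels, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss averaged over the batch.

    lengths: (batch, n_classes) lengths of the output capsules, each in [0, 1].
    labels:  (batch,) integer class ids.
    """
    t = np.eye(lengths.shape[1])[labels]  # one-hot targets
    # Penalize a short capsule for the class that is present...
    present = t * np.maximum(0.0, m_plus - lengths) ** 2
    # ...and, with a smaller weight, long capsules for classes that are absent.
    absent = lam * (1.0 - t) * np.maximum(0.0, lengths - m_minus) ** 2
    return np.mean(np.sum(present + absent, axis=1))
```

In the paper, the total loss adds the reconstruction loss to this term after scaling it down by 0.0005 so that it does not dominate the margin loss.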
You can also monitor the sample images and their reconstructed counterparts in real time from the Images tab.
Fig3. Sample original and reconstructed images in Tensorboard
After training, you can also run the code in visualize mode and get some of the results on sampled images saved in .png format. Example results for both MNIST and Fashion-MNIST data are as follows:
Fig4. Original and reconstructed images for MNIST and Fashion-MNIST data generated in visualize mode
To do soon (already added to the code)