tooth2 / Automatic-Image-Captioning

A PyTorch implementation of the CNN+RNN architecture on the MS-COCO dataset



This project combines a CNN and an RNN into a deep learning model that generates a caption for an input image. The model has two components:

  1. a CNN that transforms an input image into a set of features, and
  2. an RNN that turns those features into rich, descriptive language.

Code structure

  • 0_Dataset.ipynb
  • 1_Preliminaries.ipynb
  • 2_Training.ipynb
  • 3_Inference.ipynb


Recurrent neural networks (RNNs) learn from ordered sequences of data. The goals of this project are to:

  • Use a pre-trained model (VGG-19) for object detection and classification.
  • Combine pre-trained CNNs and RNNs to build a complex image captioning model.
  • Implement an LSTM for sequential text (image caption) generation.
  • Train a model to predict captions and understand a visual scene.


*get_loader function in

CNN encoder
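
As a rough illustration of the encoder's role, the sketch below wraps a convolutional backbone and projects its pooled output to embed_size. The class name EncoderCNN, the feature_dim argument, the frozen backbone, and the average pooling are all assumptions made for this example, not details taken from the notebooks. A tiny stand-in backbone is used so the snippet runs without downloading weights; in practice the convolutional part of a pre-trained VGG-19 would take its place.

```python
import torch
import torch.nn as nn


class EncoderCNN(nn.Module):
    """Hypothetical encoder: frozen CNN backbone + trainable linear projection."""

    def __init__(self, backbone, feature_dim, embed_size):
        super().__init__()
        self.backbone = backbone
        # Assumption: the pre-trained backbone is kept fixed during training.
        for param in self.backbone.parameters():
            param.requires_grad_(False)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.embed = nn.Linear(feature_dim, embed_size)

    def forward(self, images):
        with torch.no_grad():
            feats = self.backbone(images)        # (batch, feature_dim, H, W)
        feats = self.pool(feats).flatten(1)      # (batch, feature_dim)
        return self.embed(feats)                 # (batch, embed_size)


# Stand-in backbone for demonstration only (not VGG-19).
tiny_backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
)
encoder = EncoderCNN(tiny_backbone, feature_dim=16, embed_size=256)

images = torch.randn(4, 3, 64, 64)               # a dummy batch of 4 RGB images
features = encoder(images)
print(features.shape)                            # torch.Size([4, 256])
```

Only the final linear layer has trainable parameters in this sketch, which keeps the number of weights to optimize small.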


  • Transform to Tensor: pre-process the test images

RNN decoder
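
The decoder can be sketched as an embedding layer, an LSTM, and a linear layer that scores every vocabulary word at each time step. This is a minimal sketch under assumptions: the class name DecoderRNN, the vocab_size argument, and the choice to feed the image feature as the first LSTM input (with teacher forcing on the remaining steps) are illustrative and may differ from the notebooks.

```python
import torch
import torch.nn as nn


class DecoderRNN(nn.Module):
    """Hypothetical LSTM decoder conditioned on an image feature vector."""

    def __init__(self, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # features: (batch, embed_size), captions: (batch, seq_len) of word ids
        # Drop the last token: it is only ever a target, never an input.
        embeddings = self.embed(captions[:, :-1])
        # Assumption: the image feature acts as the first input of the sequence.
        inputs = torch.cat([features.unsqueeze(1), embeddings], dim=1)
        hidden, _ = self.lstm(inputs)            # (batch, seq_len, hidden_size)
        return self.fc(hidden)                   # (batch, seq_len, vocab_size)


decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=1000)
features = torch.randn(4, 256)                   # dummy encoder output
captions = torch.randint(0, 1000, (4, 12))       # dummy tokenized captions
scores = decoder(features, captions)
print(scores.shape)                              # torch.Size([4, 12, 1000])
```

Because the output has one score vector per caption position, it can be compared directly against the caption word ids with a cross-entropy loss.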


Training: Hyperparameter Tuning

Model Parameters

  • vocab_threshold - the minimum word count threshold. Note that a larger threshold will result in a smaller vocabulary, whereas a smaller threshold will include rarer words and result in a larger vocabulary.
  • vocab_from_file - a Boolean that decides whether to load the vocabulary from file.
  • embed_size - the dimensionality of the image and word embeddings.
  • hidden_size - the number of features in the hidden state of the RNN decoder.
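
To make the effect of vocab_threshold concrete, here is a small sketch that builds a vocabulary from a few toy captions. The helper build_vocab, the special tokens, and the whitespace tokenization are assumptions for illustration; the repository's own vocabulary code may work differently.

```python
from collections import Counter


def build_vocab(captions, vocab_threshold):
    """Map each word seen at least vocab_threshold times to an integer id."""
    counts = Counter(word for caption in captions
                     for word in caption.lower().split())
    kept = sorted(word for word, count in counts.items()
                  if count >= vocab_threshold)
    specials = ["<start>", "<end>", "<unk>"]     # assumed special tokens
    return {word: idx for idx, word in enumerate(specials + kept)}


captions = [
    "a dog runs on the grass",
    "a dog catches a frisbee",
    "a cat sleeps on the sofa",
]

small = build_vocab(captions, vocab_threshold=2)  # keeps only repeated words
large = build_vocab(captions, vocab_threshold=1)  # keeps every word
print(len(small), len(large))                     # 7 14
```

With the higher threshold, rare words such as "frisbee" fall out of the vocabulary and would be mapped to the unknown token.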

Training Parameters

  • num_epochs - the number of epochs to train the model.
  • learn_rate - the learning rate, i.e. the step size of each optimizer update.
  • batch_size - the number of image-caption pairs in each training batch, used to update the model weights in each training step.
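
As a quick worked example of how batch_size and num_epochs interact, the sketch below counts the weight updates in a training run. The dataset size is a made-up number, and the helper steps_per_epoch assumes the final, smaller batch is kept rather than dropped.

```python
import math


def steps_per_epoch(num_pairs, batch_size):
    """Number of weight updates per epoch, counting a final partial batch."""
    return math.ceil(num_pairs / batch_size)


num_pairs = 10_000      # hypothetical number of image-caption pairs
batch_size = 64
num_epochs = 3

per_epoch = steps_per_epoch(num_pairs, batch_size)
total_steps = per_epoch * num_epochs
print(per_epoch, total_steps)                    # 157 471
```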





Languages: Jupyter Notebook 99.0%, Python 1.0%