ahirsharan / MetaSegNet

Differentiable Meta-learning Model for Few-shot Semantic Segmentation

Note: In Development

Differentiable Meta-learning Model for Few-shot Semantic Segmentation (MetaSegNet)

License: MIT. Framework: PyTorch.

Requirements

PyTorch and torchvision need to be installed before running the scripts, together with PIL for data preprocessing and tqdm for showing training progress.

To run this repository, install Python 3.7 and PyTorch 1.5.0 with Anaconda.

You can download Anaconda and read the installation instructions on the official website: https://www.anaconda.com/download/

Create a new environment and install PyTorch and torchvision on it:

conda create --name mseg python=3.7
conda activate mseg
conda install pytorch=1.5.0 torchvision=0.6.0 -c pytorch
pip install pillow tqdm tensorboardX
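Before cloning, you can confirm that the environment is complete with a small import check. The sketch below is not part of the repository. Its module list (torch, torchvision, PIL, tqdm, tensorboardX) is assumed from the requirements above and from the tensorboardX logging described under Running Experiments. With the versions above, torch.__version__ should start with 1.5.0.

```python
import importlib.util

# Module names assumed from the requirements: PyTorch, torchvision,
# PIL (installed as Pillow) and tqdm, plus tensorboardX for the training logs.
REQUIRED_MODULES = ["torch", "torchvision", "PIL", "tqdm", "tensorboardX"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        import torch
        print("All packages found; torch version:", torch.__version__)
```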

Clone this repository:

git clone https://github.com/ahirsharan/MetaSegNet.git

Code Structure

The code structure is based on MTL-template and Pytorch-Segmentation.

.
├── Datasets
│   ├── COCOAug
│   ├── Pascal5Aug
│   └── FSS1000Aug
│
└── MetaSegNet
    ├── FewShotPreprocessing.py     # utility to organise the few-shot data into train and novel splits
    ├── cocogen.py                  # utility to generate masks, then organise the few-shot data into train and novel splits
    ├── augment.py                  # generic data augmentation
    │
    ├── dataloader
    │   ├── dataset_loader.py       # data loader for pre datasets
    │   └── samplers.py             # samplers for the few-shot meta-task dataset
    │
    ├── models
    │   ├── mtl.py                  # meta-transfer class
    │   └── metasegnet.py           # ResNet-9 class
    │
    ├── trainer
    │   └── meta.py                 # meta-train trainer class
    │
    ├── utils
    │   ├── gpu_tools.py            # GPU tool functions
    │   ├── metrics.py              # metric functions
    │   ├── losses.py               # loss functions
    │   ├── lovasz_losses.py        # Lovasz loss function
    │   └── misc.py                 # miscellaneous tool functions
    │
    ├── main.py                     # main function and parameter settings
    └── run_meta.py                 # script to run the meta-train and meta-test phases
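samplers.py builds the meta tasks. The sketch below shows roughly what a few-shot task sampler does. It draws way classes and, for each class, train_query training images and test_query test images, using the options listed under Hyperparameters and Options. This is a hypothetical, framework-free sketch, not the repository's sampler. The following are all assumptions: the function name sample_episode, the images_by_class dictionary layout, and the convention of labelling task classes 1..way with 0 reserved for background.

```python
import random


def sample_episode(images_by_class, way, train_query, test_query, rng=random):
    """Sample one few-shot task (hypothetical helper, not from samplers.py).

    images_by_class: dict mapping a class id to a list of image ids (assumed layout).
    Returns (train, test): lists of (image_id, task_label) pairs, where task
    labels run from 1 to `way` (0 is assumed to be the background).
    """
    per_class = train_query + test_query
    eligible = [c for c, imgs in images_by_class.items() if len(imgs) >= per_class]
    if len(eligible) < way:
        raise ValueError("not enough classes with enough images for this task")
    classes = rng.sample(eligible, way)
    train, test = [], []
    for label, cls in enumerate(classes, start=1):
        # One draw per class, so training and test images never overlap.
        picks = rng.sample(images_by_class[cls], per_class)
        train += [(img, label) for img in picks[:train_query]]
        test += [(img, label) for img in picks[train_query:]]
    return train, test


if __name__ == "__main__":
    toy = {c: [f"{c}_{i}" for i in range(10)] for c in ["cat", "dog", "bus", "car"]}
    train, test = sample_episode(toy, way=2, train_query=1, test_query=3, rng=random.Random(0))
    print(len(train), len(test))  # 2 training and 6 test pairs for a 2-way, 1-shot task
```

Drawing the training and test images for a class in a single rng.sample call keeps the two sets disjoint without any extra bookkeeping.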

Running Experiments

Run the meta-train and meta-test phases:

python run_meta.py

Test predictions and logs (saved models) are stored in the root directory under resultsx and logsx, where the suffix x can be changed in trainer/meta.py. The tensorboardX logs for loss and mIoU are stored in the runs folder inside the MetaSegNet directory.
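The mIoU logged during training is the mean intersection-over-union. The minimal NumPy sketch below computes this metric for a single pair of label maps. It is not the code in utils/metrics.py. The ignore_index of 255 is an assumed convention for unlabeled pixels. Few-shot segmentation evaluations often accumulate intersections and unions per class over all test tasks before averaging, and that can give a different number than averaging per-image scores.

```python
import numpy as np


def mean_iou(pred, target, num_classes, ignore_index=255):
    """Mean intersection-over-union over classes present in pred or target.

    pred, target: integer label maps of the same shape (0 = background).
    ignore_index: target value excluded from the computation (assumed 255).
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    valid = target != ignore_index
    ious = []
    for cls in range(num_classes):
        p = (pred == cls) & valid
        t = (target == cls) & valid
        union = np.logical_or(p, t).sum()
        if union == 0:
            continue  # class absent from both maps: skip it instead of scoring it
        ious.append(np.logical_and(p, t).sum() / union)
    return float(np.mean(ious)) if ious else float("nan")


if __name__ == "__main__":
    # Background IoU = 1/2, foreground IoU = 2/3, so mIoU = 7/12.
    print(mean_iou([[0, 0, 1, 1]], [[0, 1, 1, 1]], num_classes=2))
```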

Hyperparameters and Options

Hyperparameters and options are set in main.py:

  • model_type: the network architecture
  • mtype: ablation-study argument for choosing between MetaSegNet, MetaSegNet-NG and MetaSegConv
  • valdata: ablation-study argument for also choosing a validation set
  • dataset: the meta dataset
  • phase: train or test
  • seed: manual seed for PyTorch; "0" means a random seed is used
  • gpu: GPU id
  • dataset_dir: directory for the images
  • max_epoch: number of epochs for the meta-train phase
  • num_batch: number of different tasks used for meta-training
  • way: number of ways, i.e. how many classes are in a task (background excluded)
  • train_query: number of shots, i.e. training samples for each class in a task
  • test_query: number of test samples for each class in a task
  • meta_lr: learning rate for the embedding model
  • base_lr: learning rate for the inner loop
  • update_step: number of updates in the inner loop
  • step_size: number of epochs between reductions of the meta learning rates
  • gamma: decay factor for the meta-train learning rate
  • init_weights: pretrained weights for the meta-train phase
  • meta_label: additional label for meta-training
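base_lr and update_step describe the inner loop. meta_lr, step_size and gamma describe how the outer learning rate is decayed. The sketch below shows one common reading of these options. First, the inner loop takes update_step plain gradient-descent steps with base_lr on a copy of the weights (MAML-style fast weights). Second, a step decay multiplies meta_lr by gamma every step_size epochs, as torch.optim.lr_scheduler.StepLR does. Both functions (inner_loop, meta_lr_at) are hypothetical and operate on plain Python floats. The actual trainer in trainer/meta.py and models/mtl.py may differ.

```python
def inner_loop(weights, grad_fn, base_lr, update_step):
    """Hypothetical MAML-style adaptation: update_step gradient steps of size base_lr.

    weights: list of floats; grad_fn maps a weight list to a gradient list.
    The original weights are left untouched; adapted "fast" weights are returned.
    """
    fast = list(weights)
    for _ in range(update_step):
        grads = grad_fn(fast)
        fast = [w - base_lr * g for w, g in zip(fast, grads)]
    return fast


def meta_lr_at(epoch, meta_lr, step_size, gamma):
    """Step decay (assumed schedule): meta_lr * gamma ** (epoch // step_size), epochs from 0."""
    return meta_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    # Toy inner-loop problem: loss(w) = (w - 3)^2, so the gradient is 2 * (w - 3).
    adapted = inner_loop([0.0], lambda ws: [2 * (w - 3) for w in ws], base_lr=0.25, update_step=3)
    print(adapted)  # [2.625]
    print([meta_lr_at(e, 1.0, step_size=10, gamma=0.5) for e in (0, 9, 10, 25)])  # [1.0, 1.0, 0.5, 0.25]
```

Keeping the fast weights separate from the original weights is what lets a meta-learner differentiate through the inner loop and update the shared weights in the outer loop.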
