shahsohil / sunets

PyTorch Implementation of Stacked U-Nets (SUNets)

Stacked U-Nets (SUNets)

Introduction

This is a PyTorch implementation for training Stacked U-Nets (SUNets) models on image classification and semantic segmentation tasks, as presented in the following paper:

Sohil Shah, Pallabi Ghosh, Larry S. Davis and Tom Goldstein. Stacked U-Nets: A No-Frills Approach to Natural Image Segmentation.

If you use this code in your research, please cite our paper.

@article{shah2018sunets,
	author    = {Sohil Atul Shah and Pallabi Ghosh and Larry S. Davis and Tom Goldstein},
	title     = {Stacked U-Nets: A No-Frills Approach to Natural Image Segmentation},
	journal   = {arXiv:1804.10343},
	year      = {2018},
}

The source code and dataset are published under the MIT license. See LICENSE for details. In general, you can use the code for any purpose with proper attribution. If you do something interesting with the code, we would be happy to hear about it; feel free to contact us.

Requirements

Data

  • Dataset(s) can be downloaded using the list of URLs provided here.
  • Extract the zip / tar archive(s) and set the corresponding dataset paths in config.json.
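
The schema of config.json is not documented here. The sketch below assumes a nested layout of the form {dataset_name: {"data_path": ...}}; both the layout and the "data_path" key are assumptions, so compare against the shipped config.json before using it. It updates one dataset path in place using only the standard library.

```python
import json
from pathlib import Path


def set_dataset_path(config_file, dataset, path, key="data_path"):
    """Point `dataset` at `path` inside a JSON config file.

    ASSUMPTION: config.json maps dataset names to dicts of paths, e.g.
    {"sbd": {"data_path": "/old/path"}}. The "data_path" key is hypothetical.
    """
    config_file = Path(config_file)
    config = json.loads(config_file.read_text())
    # Create the dataset entry if missing, then overwrite just this one key.
    config.setdefault(dataset, {})[key] = str(path)
    config_file.write_text(json.dumps(config, indent=4))
    return config
```

For example, set_dataset_path("config.json", "sbd", "/data/benchmark_RELEASE/") would rewrite only the sbd entry and leave the other datasets untouched.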

Usage

ImageNet Classification

To train the model:

python train_imagenet.py [-h] [--arch ARCH] [-j N] [--epochs N]
                         [--start-epoch N] [-b N] [--lr LR] [--momentum M]
                         [--weight-decay W] [--print-freq N] [--resume PATH]
                         [-e] [--pretrained] [--world-size WORLD_SIZE]
                         [--dist-url DIST_URL] [--dist-backend DIST_BACKEND]
                         [--id ID] [--tensorboard] [--manualSeed MANUALSEED]
                         DIR

  DIR                   path to dataset
  --arch, -a            model architecture: alexnet | densenet121 |
                        densenet161 | densenet169 | densenet201 | inception_v3
                        | resnet101 | resnet152 | resnet18 | resnet34 |
                        resnet50 | squeezenet1_0 | squeezenet1_1 | vgg11 |
                        vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19
                        | vgg19_bn | sunet64 | sunet128 | sunet7128
                        (default: sunet7128)
  -j, --workers         number of data loading workers (default: 8)
  --epochs              number of total epochs to run
  --start-epoch         manual epoch number (useful on restarts)
  -b, --batch-size      mini-batch size (default: 256)
  --lr                  initial learning rate
  --momentum            momentum
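  --wd, --weight-decay  weight decay (default: 5e-4)
  --print-freq, -p      print frequency (default: 10)
  --resume              path to latest checkpoint (default: none)
  -e, --evaluate        evaluate model on validation set
  --pretrained          use pre-trained model
  --world-size          number of distributed processes
  --dist-url            URL used to set up distributed training
  --dist-backend        distributed backend
  --id                  identifier for the run (e.g. a job ID)
  --tensorboard         log progress to TensorBoard
  --manualSeed          manual random seed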
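
--manualSeed is described only as a manual seed. A common way such a flag is used, shown here as an assumption rather than a copy of train_imagenet.py, is to seed Python's and PyTorch's generators (CPU and every GPU) from the same value; if the data pipeline also draws from NumPy, np.random.seed(seed) would be added as well.

```python
import random

import torch


def seed_everything(seed):
    """Seed Python's and PyTorch's RNGs from one value (typical --manualSeed handling)."""
    random.seed(seed)
    torch.manual_seed(seed)           # PyTorch's default generators
    torch.cuda.manual_seed_all(seed)  # every visible GPU; silently ignored without CUDA
```

Seeding alone does not make GPU runs bitwise reproducible; that additionally needs torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False, at some cost in speed.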

For example, one can start training on ImageNet data using

python train_imagenet.py /path/to/imagenet/ -a sunet7128 -b 256 --resume /path/to/checkpoint/ --manualSeed 0 --id $JOBID --tensorboard --lr 0.01 --epochs 100
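
The usage above only describes DIR as the path to the dataset. Since the flags of train_imagenet.py mirror the PyTorch ImageNet example, DIR most likely follows that example's layout: DIR/train and DIR/val, each holding one sub-folder per class (the torchvision ImageFolder convention, which means the flat ILSVRC2012 val folder must first be reorganized into class folders). That layout is an assumption here; the hypothetical helper below checks it before a long run is launched.

```python
import os


def check_imagenet_layout(root, splits=("train", "val")):
    """Return {split: number_of_class_folders}; raise if a split folder is missing.

    ASSUMPTION: ImageFolder-style layout, i.e. root/<split>/<class_name>/<image>.
    """
    counts = {}
    for split in splits:
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"expected split folder: {split_dir}")
        # Each immediate sub-directory is treated as one class.
        counts[split] = sum(1 for entry in os.scandir(split_dir) if entry.is_dir())
    return counts
```

On a complete ImageNet-1k copy, check_imagenet_layout("/path/to/imagenet/") should report 1000 class folders for each split.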

To view the results, run:

tensorboard --logdir logs/

Semantic Segmentation

To train the model:

python train_seg.py [-h] [--arch [ARCH]] [--model_path MODEL_PATH]
                    [--dataset [DATASET]] [--img_rows [IMG_ROWS]]
                    [--img_cols [IMG_COLS]] [--n_epoch [N_EPOCH]]
                    [--batch_size [BATCH_SIZE]] [--l_rate [L_RATE]]
                    [--manualSeed MANUALSEED] [--iter_size ITER_SIZE]
                    [--log_size LOG_SIZE] [--momentum [MOMENTUM]] [--wd [WD]]
                    [--optim [OPTIM]] [--ost [OST]] [--freeze] [--restore]
                    [--split [SPLIT]]

  --arch                Architecture to use (e.g. sunet64, sunet128, sunet7128)
  --model_path          Path to the saved model
  --dataset             Dataset to use (e.g. sbd, coco, cityscapes)
  --img_rows            Height of the input image
  --img_cols            Width of the input image
  --n_epoch             Number of epochs
  --batch_size          Batch size
  --l_rate              Learning rate
  --manualSeed          Manual random seed
  --iter_size           Number of mini-batches accumulated per weight update
  --log_size            Iteration period for logging segmented images
  --momentum            Momentum for SGD
  --wd                  Weight decay
  --optim               Optimizer to use (e.g. SGD, Nesterov)
  --ost                 Output stride to use (e.g. 32, 16, 8)
  --freeze              Freeze BatchNorm parameters
  --restore             Restore optimizer parameters
  --split               Sets to use (e.g. train_aug, train, trainvalrare, trainval_aug, trainval)
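
--iter_size and --freeze correspond to two standard training tricks. The sketch below is a generic PyTorch illustration, not code taken from train_seg.py: gradients are accumulated over iter_size mini-batches before each optimizer step (so one update sees batch_size * iter_size images), and freezing BN puts every BatchNorm layer in eval mode, so its running statistics are not updated, and stops gradients to its affine parameters. The function names are hypothetical.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_bn(model):
    """Keep BatchNorm running stats and affine params fixed (what --freeze describes)."""
    for m in model.modules():
        if isinstance(m, BN_TYPES):
            m.eval()  # use stored running_mean / running_var, do not update them
            for p in m.parameters():
                p.requires_grad = False  # grad stays None, so SGD skips these


def train_epoch(model, loader, optimizer, criterion, iter_size=1, freeze=False):
    """One epoch with gradient accumulation over `iter_size` mini-batches."""
    model.train()
    if freeze:
        freeze_bn(model)  # must come after model.train(), which re-enables BN updates
    optimizer.zero_grad()
    for i, (images, labels) in enumerate(loader, start=1):
        # Divide so the accumulated gradient is the mean over iter_size batches.
        loss = criterion(model(images), labels) / iter_size
        loss.backward()  # gradients add up in .grad until the next step
        if i % iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()
```

Accumulation trades time for memory: with --iter_size 2, a GPU that fits only 11 crops can emulate a batch of 22, although BatchNorm statistics are still computed per 11-crop mini-batch.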

For example, one can start fine-tuning on PASCAL VOC 2012 data using

python train_seg.py --arch sunet7128 --dataset sbd --batch_size 22 --iter_size 1 --n_epoch 90 --l_rate 0.0002 --momentum 0.95 --wd 1e-4 --optim SGD --img_rows 512 --img_cols 512 --ost 16
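
In the command above, --ost 16 means the final feature map is 16 times smaller than the input crop in each dimension, and with --iter_size 1 each weight update sees 22 images. The short calculation below makes both numbers explicit; it is plain arithmetic over the flags, not code from the repository, and assumes the crop size is an exact multiple of the output stride.

```python
def feature_map_size(img_rows, img_cols, ost):
    """Spatial size of the final feature map for a given output stride."""
    # Assumes the crop size is an exact multiple of the output stride.
    return img_rows // ost, img_cols // ost


def effective_batch(batch_size, iter_size):
    """Images contributing to one optimizer step under gradient accumulation."""
    return batch_size * iter_size


# Values from the VOC fine-tuning command above.
print(feature_map_size(512, 512, 16))  # (32, 32)
print(effective_batch(22, 1))          # 22
```

Lowering the output stride to 8 gives a 64 x 64 map, four times as many positions, which is why smaller strides need noticeably more memory.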

To validate the model at multiple scales:

python test_multiscale.py [-h] [--arch [ARCH]] [--model_path [MODEL_PATH]]
                          [--dataset [DATASET]] [--img_rows [IMG_ROWS]]
                          [--img_cols [IMG_COLS]] [--ost [OST]]

  --arch                Architecture to use (e.g. sunet64, sunet128, sunet7128)
  --model_path          Path to the saved model
  --dataset             Dataset to use (e.g. sbd, cityscapes)
  --img_rows            Height of the crop
  --img_cols            Width of the crop
  --ost                 Output stride to use (e.g. 32, 16, 8)

For example, one can validate on the PASCAL VOC 2012 validation set using

python test_multiscale.py --arch sunet7128 --dataset sbd --model_path /path/to/checkpoint --img_rows 512 --img_cols 512 --ost 16
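
The scale set and fusion rule used by test_multiscale.py are not documented here. A common scheme, shown as a generic PyTorch sketch with a hypothetical scale list and function name, is to resize the input to several scales (optionally adding a horizontally flipped copy), resize each set of class scores back to the original resolution, accumulate them, and take the per-pixel argmax.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def multiscale_predict(model, image, scales=(0.5, 0.75, 1.0, 1.25, 1.5), flip=True):
    """Accumulate class probabilities over rescaled (and optionally flipped) inputs.

    image: float tensor of shape (N, C, H, W). The scale list is hypothetical.
    Returns a (N, H, W) tensor of predicted class indices.
    """
    _, _, h, w = image.shape
    total = 0
    for s in scales:
        scaled = F.interpolate(image, scale_factor=s, mode="bilinear", align_corners=False)
        inputs = [scaled, scaled.flip(-1)] if flip else [scaled]
        for k, x in enumerate(inputs):
            logits = model(x)
            if k == 1:
                logits = logits.flip(-1)  # undo the horizontal flip
            # Resize scores back to the original resolution before summing.
            logits = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
            total = total + F.softmax(logits, dim=1)
    return total.argmax(dim=1)
```

The caller is expected to put the model in eval mode first; summing rather than averaging the probabilities does not change the argmax.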

To evaluate the model on custom image(s):

python evaluate_pascal.py [-h] [--arch [ARCH]] [--model_path [MODEL_PATH]]
                          [--dataset [DATASET]] [--img_rows [IMG_ROWS]]
                          [--img_cols [IMG_COLS]] [--img_path [IMG_PATH]]
                          [--out_path [OUT_PATH]] [--coco [COCO]]
                          [--split SPLIT] [--ost [OST]]

  --arch, --model_path, --dataset, --img_rows, --img_cols
                        Same as for test_multiscale.py
  --img_path            Path to the input image(s)
  --out_path            Output folder for the predicted segmentation maps, arranged as required by the PASCAL VOC evaluation server
  --coco                Whether the model was trained with external data (COCO)
  --split               Split to evaluate: val or test
  --ost                 Output stride to use (e.g. 32, 16, 8)

For example, one can evaluate on PASCAL VOC 2012 data with the following command, which uses the val split (pass --split test for the test set):

python evaluate_pascal.py --arch sunet7128 --dataset sbd --model_path /path/to/checkpoint --img_rows 512 --img_cols 512 --split val --img_path /path/to/images --out_path /path/to/output_folder --ost 16
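
--out_path is described only as following the PASCAL server requirements. Under the standard VOC2012 devkit convention (an assumption about what evaluate_pascal.py writes), each prediction is a palette-indexed PNG named after the image ID and placed in results/VOC2012/Segmentation/comp5_<split>_cls/ for models trained on VOC data only, or comp6_<split>_cls/ when external data such as COCO was used. The hypothetical helpers below build that path and the standard VOC colour palette.

```python
import os


def voc_result_path(out_path, image_id, split="test", external_data=False):
    """Location of one segmentation result PNG in the VOC2012 devkit layout."""
    comp = "comp6" if external_data else "comp5"  # comp6 = trained with external data (e.g. COCO)
    return os.path.join(out_path, "results", "VOC2012", "Segmentation",
                        f"{comp}_{split}_cls", f"{image_id}.png")


def voc_palette(num_colors=256):
    """Standard PASCAL VOC colour map as a flat [r0, g0, b0, r1, ...] list."""
    palette = []
    for label in range(num_colors):
        r = g = b = 0
        c = label
        for bit in range(8):
            # Spread the label's bits, three at a time, over the high bits of each channel.
            r |= ((c >> 0) & 1) << (7 - bit)
            g |= ((c >> 1) & 1) << (7 - bit)
            b |= ((c >> 2) & 1) << (7 - bit)
            c >>= 3
        palette.extend((r, g, b))
    return palette
```

The palette can be attached to a mode "P" label image with Pillow's Image.putpalette, so that class indices are stored as-is while the PNG still renders in the familiar VOC colours.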

To view the results:

Launch the visdom server in a separate terminal window, then run display.py:

python -m visdom.server
python display.py [--images]

The --images option additionally displays a few validation images.

Acknowledgements

Parts of the code are inspired by the PyTorch implementation of semantic segmentation models by @meetshah1995 and @zijundeng.
