wanjinchang / pytorch-segmentation

PyTorch implementation for semantic segmentation (DeepLabV3+, UNet, etc.)

PytorchSegmentation

This repository implements general-purpose networks for semantic segmentation.
You can train various networks, such as DeepLabV3+, PSPNet, and UNet, simply by writing a config file.
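One way to picture this config-driven design is a lookup from config strings to model builders. The sketch below is hypothetical: the `ENCODERS`/`DECODERS` registries and the `build_net` function are illustrative names, not this repository's code. It only shows how the `enc_type` and `dec_type` strings from a config could select components. The builders return strings so it runs without PyTorch.

```python
from typing import Callable, Dict

# Hypothetical registries: config strings -> builder functions. A real
# project would return torch.nn.Module objects; these return strings so
# the sketch runs without PyTorch installed.
ENCODERS: Dict[str, Callable[[dict], str]] = {
    "resnet18": lambda cfg: f"ResNet18(pretrained={cfg.get('pretrained', False)})",
    "xception65": lambda cfg: f"Xception65(pretrained={cfg.get('pretrained', False)})",
}
DECODERS: Dict[str, Callable[[dict], str]] = {
    "unet_scse": lambda cfg: f"UNetSCSE(num_filters={cfg['num_filters']})",
}


def build_net(net_cfg: dict) -> str:
    """Pick an encoder and a decoder from the `Net` section of a config."""
    enc_type, dec_type = net_cfg["enc_type"], net_cfg["dec_type"]
    if enc_type not in ENCODERS:
        raise ValueError(f"unknown enc_type: {enc_type!r}")
    if dec_type not in DECODERS:
        raise ValueError(f"unknown dec_type: {dec_type!r}")
    return f"{ENCODERS[enc_type](net_cfg)} -> {DECODERS[dec_type](net_cfg)}"


print(build_net({"enc_type": "resnet18", "dec_type": "unet_scse",
                 "num_filters": 8, "pretrained": True}))
# ResNet18(pretrained=True) -> UNetSCSE(num_filters=8)
```

Adding a new architecture in a design like this means registering one more builder; the training script itself does not change.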

DeepLabV3+

Pretrained model

You can run a pretrained DeepLabV3+ model converted from the official TensorFlow model.
Currently, I have confirmed that xception65_cityscapes_trainfine can be converted.

$ mkdir tf_model
$ cd tf_model
$ wget http://download.tensorflow.org/models/deeplabv3_cityscapes_train_2018_02_06.tar.gz
$ tar -xvf deeplabv3_cityscapes_train_2018_02_06.tar.gz
$ cd ../src
$ mkdir -p ../model/cityscapes_deeplab_v3_plus
$ python convert.py ../tf_model/deeplabv3_cityscapes_train/model.ckpt 19 ../model/cityscapes_deeplab_v3_plus/model.pth
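Part of any TensorFlow-to-PyTorch conversion is reordering kernel tensors. TensorFlow stores regular conv kernels as (height, width, in_channels, out_channels) and depthwise kernels as (height, width, in_channels, channel_multiplier). PyTorch `nn.Conv2d` expects (out_channels, in_channels / groups, height, width). The NumPy sketch below shows only this reordering. It is an assumption about what a converter like `convert.py` must do internally, not its actual code. The mapping from TensorFlow variable names to PyTorch parameter names is omitted.

```python
import numpy as np


def tf_conv_to_torch(kernel: np.ndarray) -> np.ndarray:
    """Regular conv: TF (H, W, in, out) -> PyTorch (out, in, H, W)."""
    return np.ascontiguousarray(kernel.transpose(3, 2, 0, 1))


def tf_depthwise_to_torch(kernel: np.ndarray) -> np.ndarray:
    """Depthwise conv: TF (H, W, in, mult) -> PyTorch (in * mult, 1, H, W).

    PyTorch expresses depthwise conv as Conv2d(groups=in_channels); output
    channel c * mult + m reads input channel c with multiplier index m,
    which matches TensorFlow's depthwise output ordering.
    """
    h, w, in_ch, mult = kernel.shape
    # (H, W, in, mult) -> (in, mult, H, W) -> (in * mult, 1, H, W)
    return np.ascontiguousarray(
        kernel.transpose(2, 3, 0, 1).reshape(in_ch * mult, 1, h, w))


tf_kernel = np.random.rand(3, 3, 16, 32).astype(np.float32)
print(tf_conv_to_torch(tf_kernel).shape)   # (32, 16, 3, 3)
tf_dw = np.random.rand(3, 3, 64, 1).astype(np.float32)
print(tf_depthwise_to_torch(tf_dw).shape)  # (64, 1, 3, 3)
```

Xception65 relies heavily on depthwise-separable convolutions, so both cases matter for this backbone.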

Then you can test the performance of the converted network.

$ python eval.py
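Semantic segmentation is commonly scored by mean intersection-over-union (mIoU). This README does not document what `eval.py` reports, so treat the following as an assumption about the metric rather than a description of the script. The sketch computes mIoU from a confusion matrix. Ignoring label 255 follows the usual Cityscapes and Pascal VOC convention for void pixels and is also an assumption.

```python
import numpy as np


def confusion_matrix(pred, target, num_classes, ignore_index=255):
    """Count (ground-truth, prediction) pairs; rows are ground-truth classes."""
    pred = np.asarray(pred, dtype=np.int64).ravel()
    target = np.asarray(target, dtype=np.int64).ravel()
    valid = target != ignore_index  # drop void/boundary pixels
    idx = num_classes * target[valid] + pred[valid]
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def mean_iou(conf):
    """Per-class IoU = TP / (TP + FP + FN) and their mean over classes seen."""
    tp = np.diag(conf).astype(np.float64)
    denom = conf.sum(axis=0) + conf.sum(axis=1) - tp
    iou = np.full(tp.shape, np.nan)
    seen = denom > 0  # classes absent from both prediction and GT are skipped
    iou[seen] = tp[seen] / denom[seen]
    return iou, float(np.nanmean(iou))


target = [[0, 0, 1], [1, 2, 255]]
pred = [[0, 1, 1], [1, 2, 2]]
per_class, miou = mean_iou(confusion_matrix(pred, target, num_classes=3))
print(per_class)
print(round(miou, 4))  # 0.7222
```

Accumulating the confusion matrix over the whole validation set and computing IoU once at the end gives the dataset-level mIoU. Averaging per-image scores would give a different number.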

How to train

To train a model, you only need to set up a config file.
For example, write a config file like the one below and save it as config/pascal_unet_res18_scse.yaml.

Net:
  enc_type: 'resnet18'
  dec_type: 'unet_scse'
  num_filters: 8
  pretrained: True

Data:
  dataset: 'pascal'
  preprocess: 'imagenet'
  target_size: (256, 256)

Train:
  max_epoch: 20
  batch_size: 2
  resume: False
  start_epoch: 0

Loss:
  weight:
  size_average: True
  batch_average: True

Optimizer:
  mode: 'adam'
  base_lr: 0.001
  t_max: 10
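One detail matters when reading this file: YAML has no tuple syntax. A loader such as PyYAML's `yaml.safe_load` returns `target_size: (256, 256)` as the string `'(256, 256)'` and an empty key such as `weight:` as `None`. The sketch below assumes the config has already been loaded into a dict, written out literally here to avoid a third-party dependency. It shows one safe way to turn that string into a tuple with `ast.literal_eval`. The `parse_target_size` helper is hypothetical and not part of the repository.

```python
import ast
from typing import Optional, Tuple

# What yaml.safe_load would return for the Data and Loss sections above.
config = {
    "Data": {"dataset": "pascal", "preprocess": "imagenet", "target_size": "(256, 256)"},
    "Loss": {"weight": None, "size_average": True, "batch_average": True},
}


def parse_target_size(value) -> Optional[Tuple[int, int]]:
    """Hypothetical helper: accept '(H, W)', [H, W], or None."""
    if value is None:
        return None
    if isinstance(value, str):
        value = ast.literal_eval(value)  # parses literals only, never runs code
    height, width = (int(v) for v in value)
    return height, width


target_size = parse_target_size(config["Data"]["target_size"])
print(target_size)  # (256, 256)
```

`ast.literal_eval` is preferable to `eval` here because a config file should never be able to execute arbitrary code.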

Then you can train this model with:

$ python train.py ../config/pascal_unet_res18_scse.yaml
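The `Optimizer` section sets `base_lr` and `t_max`. This README does not say how `t_max` is used. The name matches the `T_max` argument of PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR`, so the sketch below assumes a cosine-annealing schedule and evaluates its closed form, with `eta_min = 0` as a further assumption.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 0.001, t_max: int = 10,
              eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in range(0, 21, 5):
    print(epoch, round(cosine_lr(epoch), 6))
# 0 0.001
# 5 0.0005
# 10 0.0
# 15 0.0005
# 20 0.001
```

Under this assumption, `max_epoch: 20` with `t_max: 10` decays the learning rate to zero at epoch 10 and raises it back to `base_lr` by epoch 20. Setting `t_max` equal to `max_epoch` would give a single monotonic decay instead.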

Dataset

Directory tree

.
├── config
├── data
│   ├── cityscapes
│   │   ├── gtFine
│   │   └── leftImg8bit
│   └── pascal_voc_2012
│       └── VOCdevkit
│           └── VOC2012
│               ├── JPEGImages
│               ├── SegmentationClass
│               └── SegmentationClassAug
├── logs
├── model
└── src
    ├── dataset
    ├── logger
    ├── losses
    │   ├── binary
    │   └── multi
    ├── models
    │   └── inplace_abn
    └── utils
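Under `VOC2012`, Pascal VOC pairs each image `JPEGImages/<id>.jpg` with a label map `SegmentationClass/<id>.png`. The augmented `SegmentationClassAug` set uses the same `<id>.png` naming. Below is a minimal `pathlib` sketch that collects such pairs. The `find_pairs` name and the choice to skip images without a label map are assumptions, not this repository's dataset code.

```python
from pathlib import Path
from typing import List, Tuple


def find_pairs(voc_root: Path,
               label_dir: str = "SegmentationClass") -> List[Tuple[Path, Path]]:
    """Return (image, label) path pairs for every image that has a label map."""
    image_dir = voc_root / "JPEGImages"
    pairs = []
    for image_path in sorted(image_dir.glob("*.jpg")):
        label_path = voc_root / label_dir / f"{image_path.stem}.png"
        if label_path.exists():  # many VOC images have no segmentation label
            pairs.append((image_path, label_path))
    return pairs


# Path relative to src/, matching the tree above.
root = Path("../data/pascal_voc_2012/VOCdevkit/VOC2012")
print(len(find_pairs(root)))
print(len(find_pairs(root, "SegmentationClassAug")))
```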

Environments

  • OS: Ubuntu 18.04
  • Python: 3.6.4
  • PyTorch: 0.4.1


Languages

Python 83.1%, CUDA 8.4%, C++ 8.4%, Shell 0.1%