Segmentation Models is a Python library of neural networks for image segmentation, based on the Keras (TensorFlow) framework.
The main features of this library are:
- High-level API (just two lines to create a neural network)
- 4 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 25 available backbones for each architecture
- All backbones have pre-trained weights for faster and better convergence
Since the library is built on the Keras framework, each segmentation model is just a Keras `Model` and can be created as easily as:
```python
from segmentation_models import Unet

model = Unet()
```
Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters, and use pretrained weights to initialize it:
```python
model = Unet('resnet34', encoder_weights='imagenet')
```
Change the number of output classes in the model:
```python
model = Unet('resnet34', classes=3, activation='softmax')
```
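With `classes=3` and `activation='softmax'`, the model outputs, for every pixel, a probability distribution over the three classes; a sigmoid activation, the usual choice for a single-class binary mask, instead scores each output channel independently. The plain-Python sketch below is not library code: it only illustrates the difference on hypothetical raw scores for one pixel.

```python
import math


def sigmoid(x):
    """Independent probability for one output channel (activation='sigmoid')."""
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits):
    """Mutually exclusive class probabilities (activation='softmax')."""
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw scores for a single pixel of a classes=3 model
pixel_logits = [2.0, 0.5, -1.0]

soft = softmax(pixel_logits)
sig = [sigmoid(v) for v in pixel_logits]
label = max(range(len(soft)), key=soft.__getitem__)  # arg-max class index

print("softmax:", [round(p, 3) for p in soft], "sum =", round(sum(soft), 6))
print("sigmoid:", [round(p, 3) for p in sig], "sum =", round(sum(sig), 3))
print("predicted class:", label)
```

Because the softmax probabilities always sum to 1, taking the arg-max per pixel yields exactly one label, which is what multi-class segmentation needs; the sigmoid scores can sum to more than 1.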
Change the input shape of the model:
```python
model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
```
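A minimal sketch of two constraints worth checking before building a model, stated here as assumptions: with `(None, None, C)` the model should accept any image size as long as height and width are divisible by 32 (as the library's docstrings indicate, as far as we know), and ImageNet encoder weights were trained on 3-channel images, which is consistent with the 6-channel example using `encoder_weights=None`. The helpers `check_input_shape` and `pad_to_multiple` are hypothetical and not part of segmentation_models.

```python
def check_input_shape(input_shape, encoder_weights=None, factor=32):
    """Hypothetical sanity check for an (H, W, C) input shape.

    Assumes H and W must be divisible by `factor` (32 = 2**5, one halving per
    encoder stage) and that ImageNet encoder weights require C == 3.
    """
    height, width, channels = input_shape
    for name, size in (("height", height), ("width", width)):
        # None means "any size"; the actual images must still be divisible
        if size is not None and size % factor != 0:
            raise ValueError("{}={} is not divisible by {}".format(name, size, factor))
    if encoder_weights == "imagenet" and channels != 3:
        raise ValueError("ImageNet weights expect 3 channels, got {}; "
                         "use encoder_weights=None".format(channels))


def pad_to_multiple(size, factor=32):
    """Return (before, after) padding that rounds `size` up to a multiple of `factor`."""
    total = (-size) % factor
    return total // 2, total - total // 2


# Usage: the 6-channel configuration without pretrained weights passes the check
check_input_shape((None, None, 6), encoder_weights=None)
print(pad_to_multiple(250))  # -> (3, 3), i.e. pad a 250-pixel side up to 256
```

Padding images (and their masks) symmetrically keeps the object centered; cropping to a smaller multiple of 32 is an equally valid choice if border pixels can be discarded.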
A simple training pipeline:
```python
from segmentation_models import Unet
from segmentation_models.backbones import get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)

# load your data (load_data is a placeholder for your own loading code)
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = Unet(BACKBONE, encoder_weights='imagenet')
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])

# fit model
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)
```
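The pipeline above trains with `bce_jaccard_loss` and reports `iou_score`. Assuming, as the name suggests, that the loss is binary cross-entropy plus (1 - IoU), the plain-Python sketch below shows what these quantities measure on a tiny, made-up four-pixel mask. The helpers `soft_iou`, `binary_crossentropy` and `bce_plus_jaccard` are hypothetical; the library's versions operate on Keras tensors and may differ in details such as smoothing and per-image averaging.

```python
import math


def soft_iou(gt, pr, smooth=1e-12):
    """Soft IoU (Jaccard index) of flat lists: 0/1 ground truth vs. probabilities."""
    intersection = sum(g * p for g, p in zip(gt, pr))
    union = sum(gt) + sum(pr) - intersection
    return (intersection + smooth) / (union + smooth)


def binary_crossentropy(gt, pr, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped away from 0 and 1."""
    total = 0.0
    for g, p in zip(gt, pr):
        p = min(max(p, eps), 1.0 - eps)
        total -= g * math.log(p) + (1 - g) * math.log(1.0 - p)
    return total / len(gt)


def bce_plus_jaccard(gt, pr):
    """Per-pixel BCE plus (1 - IoU), so both pixel accuracy and overlap count."""
    return binary_crossentropy(gt, pr) + (1.0 - soft_iou(gt, pr))


gt = [1, 1, 0, 0]              # made-up four-pixel mask
good = [0.9, 0.8, 0.1, 0.2]    # confident, mostly correct prediction
bad = [0.2, 0.3, 0.7, 0.9]     # prediction that misses the object

for name, pr in (("good", good), ("bad", bad)):
    print(name, "IoU = %.3f" % soft_iou(gt, pr), "loss = %.3f" % bce_plus_jaccard(gt, pr))
```

The IoU term rewards overlap with the object as a whole, which helps when the foreground is small, while the cross-entropy term keeps a smooth per-pixel gradient.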
The same manipulations can be done with Linknet, PSPNet and FPN. For more detailed information about the models API and use cases, see the documentation on Read the Docs.
Models
- Unet
- Linknet
- PSPNet
- FPN
Backbones
Type | Names |
---|---|
VGG | 'vgg16' 'vgg19' |
ResNet | 'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152' |
SE-ResNet | 'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152' |
ResNeXt | 'resnext50' 'resnext101' |
SE-ResNeXt | 'seresnext50' 'seresnext101' |
SENet154 | 'senet154' |
DenseNet | 'densenet121' 'densenet169' 'densenet201' |
Inception | 'inceptionv3' 'inceptionresnetv2' |
MobileNet | 'mobilenet' 'mobilenetv2' |
All backbones have weights trained on the 2012 ILSVRC ImageNet dataset (`encoder_weights='imagenet'`).
Requirements
- Python 3.5+
- Keras >= 2.2.0
- Keras Applications >= 1.0.7
- Image Classifiers == 0.2.0
- TensorFlow 1.9 (tested)
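A minimal sketch for checking an environment against the requirements list. It is a hypothetical helper, not part of the library: you supply a `{package: version}` mapping yourself (the keys used here are illustrative), and versions are compared as simple integer tuples rather than with a full version parser.

```python
import sys

# Minimum and pinned versions from the requirements list; the dictionary
# keys are illustrative names, not necessarily the exact PyPI names.
MINIMUM = {"keras": "2.2.0", "keras_applications": "1.0.7"}
PINNED = {"image_classifiers": "0.2.0"}


def parse_version(text):
    """Turn '2.2.4' into (2, 2, 4); trailing non-digits such as 'rc1' are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    parts.extend([0] * (3 - len(parts)))  # pad '2.2' to (2, 2, 0)
    return tuple(parts)


def check_requirements(installed):
    """Return a list of problems for a {package: version} mapping (empty if OK)."""
    problems = []
    if sys.version_info < (3, 5):
        problems.append("Python 3.5+ is required")
    for name, minimum in MINIMUM.items():
        found = installed.get(name)
        if found is None or parse_version(found) < parse_version(minimum):
            problems.append("{} >= {} is required (found {})".format(name, minimum, found))
    for name, pinned in PINNED.items():
        found = installed.get(name)
        if found is None or parse_version(found) != parse_version(pinned):
            problems.append("{} == {} is required (found {})".format(name, pinned, found))
    return problems


# Usage with a hand-written mapping; prints [] when everything matches
print(check_requirements({"keras": "2.2.4", "keras_applications": "1.0.7",
                          "image_classifiers": "0.2.0"}))
```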
Pip package
```bash
$ pip install segmentation-models
```
Latest version
```bash
$ pip install git+https://github.com/qubvel/segmentation_models
```
The latest documentation is available on Read the Docs.
To see important changes between versions, look at CHANGELOG.md.
The project is distributed under the MIT License.