A simple tool to help train a semantic segmentation network for a video or a set of images:
- Label the first frame of the video using labelme [https://github.com/wkentaro/labelme].
- Generate the database for training.
- Train a U-Net-like network.
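labelme saves annotations as JSON, while training expects a multilayer 0/1 label in an .npz file. Below is a minimal sketch of that conversion using NumPy and Pillow. It assumes labelme's usual JSON layout: `imageHeight`, `imageWidth`, and a `shapes` list whose entries carry `label`, `points` and `shape_type`. The function name and the `class_names` argument are hypothetical. Only polygon shapes are handled, and no background channel is added.

```python
import numpy as np
from PIL import Image, ImageDraw


def labelme_to_multilayer(labelme_json, class_names):
    """Rasterize labelme polygons into an (H, W, C) uint8 array of 0s and 1s.

    labelme_json: the parsed JSON dict (e.g. from json.load).
    class_names: hypothetical list fixing which channel each label goes to.
    """
    height, width = labelme_json["imageHeight"], labelme_json["imageWidth"]
    layers = np.zeros((height, width, len(class_names)), dtype=np.uint8)
    for shape in labelme_json["shapes"]:
        # This sketch only rasterizes polygons; other shape types are skipped.
        if shape.get("shape_type", "polygon") != "polygon":
            continue
        if shape["label"] not in class_names:
            continue
        channel = class_names.index(shape["label"])
        # Draw the polygon into a single-channel mask (Pillow sizes are (width, height)).
        mask = Image.new("L", (width, height), 0)
        points = [tuple(p) for p in shape["points"]]
        ImageDraw.Draw(mask).polygon(points, outline=1, fill=1)
        # OR so that several polygons with the same label accumulate.
        layers[:, :, channel] |= np.asarray(mask, dtype=np.uint8)
    return layers
```

The result is (H, W, C), NumPy's usual row-major image layout. If the training code really expects (W, H, C), transpose with `layers.transpose(1, 0, 2)` before saving. You can then save it with, for example, `np.savez_compressed("frame0.npz", label=layers)`; the key name `label` is an assumption.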
Requirements:
- An NVIDIA graphics card for CUDA training.
- Currently works on Windows 10. However, it is easy to modify it for Ubuntu.
- Python >= 3.6
- labelme installed
- PyTorch >= 0.4.1
- skimage, PIL, pydensecrf (install with pip; the pip package names for skimage and PIL are scikit-image and Pillow)
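You can check an environment against the list above before training with a small script like the following. It is a sketch, not part of the project. The module names are import names (`skimage`, `PIL`, `pydensecrf`, `labelme`), and the PyTorch version and CUDA checks only run if PyTorch is installed.

```python
import importlib.util
import re
import sys

# Import names of the required packages (pip names differ for some of them).
REQUIRED_MODULES = ["torch", "skimage", "PIL", "pydensecrf", "labelme"]


def version_tuple(version):
    """Turn strings like '0.4.1' or '1.13.1+cu117' into (0, 4, 1) / (1, 13, 1)."""
    numbers = []
    for part in version.split("+")[0].split("."):
        match = re.match(r"\d+", part)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


def version_at_least(version, minimum):
    """Compare two version strings, padding the shorter one with zeros."""
    a, b = version_tuple(version), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment():
    """Return a list of problems found; an empty list means everything looks fine."""
    problems = []
    if sys.version_info < (3, 6):
        problems.append("Python >= 3.6 is required")
    for name in REQUIRED_MODULES:
        if importlib.util.find_spec(name) is None:
            problems.append("missing module: " + name)
    if importlib.util.find_spec("torch") is not None:
        import torch
        if not version_at_least(torch.__version__, "0.4.1"):
            problems.append("PyTorch >= 0.4.1 is required")
        if not torch.cuda.is_available():
            problems.append("CUDA is not available for training")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print(problem)
```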
- First, label the first frame of the video. This can be done in any way you like, but the result must be stored in an .npz file containing a NumPy array of size (W, H, C), where W and H are the width and height of the image and C is the number of label channels (one per class). The array must contain only 0s and 1s. If you don't know how to generate this type of file, you can use ./data_gen/utils/rgb2multilayers to convert a color label image (Figure 1) into a multilayer label. The color list is currently [black, green, red, orange], so at most 4 classes are supported.
Figure 1: rgb label example
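Below is a sketch of the color-to-multilayer conversion, assuming NumPy. It is not the actual `rgb2multilayers` code. The exact RGB values for black, green, red and orange are assumptions, and the channel order simply follows the color list. Pixels that match none of the colors (for example, anti-aliased edges) end up with all channels 0, so the hypothetical `count_unmatched` helper lets you check for them before saving.

```python
import numpy as np

# Assumed RGB values for the color list; the real script may use different ones.
COLORS = {
    "black": (0, 0, 0),
    "green": (0, 255, 0),
    "red": (255, 0, 0),
    "orange": (255, 165, 0),
}


def rgb_to_multilayer(rgb, colors=COLORS):
    """Convert an (H, W, 3) color label into an (H, W, C) uint8 array of 0/1."""
    rgb = np.asarray(rgb)[:, :, :3]
    # One boolean layer per color: True where all three channels match exactly.
    layers = [np.all(rgb == np.array(value), axis=2) for value in colors.values()]
    return np.stack(layers, axis=2).astype(np.uint8)


def count_unmatched(layers):
    """Number of pixels that matched no color (all channels are 0)."""
    return int(np.sum(layers.sum(axis=2) == 0))
```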
- To predict a single image, for example:
python predict.py --model 'path/to/model.pth' --input 'path/to/image/to/predict' --viz
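`--viz` presumably displays the prediction. The sketch below shows one common way to turn per-class network output into a viewable image. It assumes the model returns scores of shape (C, H, W) (PyTorch's channel-first convention), with channels in the same order as the label colors. The palette values are the same assumptions as for the label conversion, and the function names are hypothetical. If the network is trained with independent 0/1 outputs per channel, thresholding each channel (for example at 0.5) may fit better than `argmax`.

```python
import numpy as np

# Assumed colors, in the channel order black, green, red, orange.
PALETTE = np.array(
    [[0, 0, 0], [0, 255, 0], [255, 0, 0], [255, 165, 0]], dtype=np.uint8
)


def scores_to_labels(scores):
    """Collapse per-class scores of shape (C, H, W) into an (H, W) class-index map."""
    return np.argmax(scores, axis=0)


def labels_to_rgb(labels, palette=PALETTE):
    """Color each pixel with its class color; returns an (H, W, 3) uint8 image."""
    return palette[labels]
```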
- To predict a set of frames from a video (all frames must have the same size), for example:
python predict_batch.py --model 'path/to/model' --input 'path/to/images/' --output 'path/to/save/result'
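The batch script requires frames of the same size. A plausible reason is that frames are stacked into one array per batch, and `np.stack` only accepts arrays of identical shape. The sketch below uses hypothetical helpers built on NumPy and Pillow, not the code in `predict_batch.py`. It reads frames in sorted file-name order and checks their sizes before batching.

```python
import os

import numpy as np


def load_frames(directory, extensions=(".png", ".jpg", ".jpeg")):
    """Lazily read images in sorted file-name order.

    Sorting is lexicographic, so zero-padded names (frame_0001.png) keep
    video order; unpadded names (frame_10 before frame_2) do not.
    """
    from PIL import Image  # Pillow, already in the requirements
    for name in sorted(os.listdir(directory)):
        if name.lower().endswith(extensions):
            yield np.array(Image.open(os.path.join(directory, name)).convert("RGB"))


def iter_batches(frames, batch_size):
    """Yield stacked (N, H, W, C) batches, failing early on a size mismatch."""
    batch, expected_shape = [], None
    for index, frame in enumerate(frames):
        frame = np.asarray(frame)
        if expected_shape is None:
            expected_shape = frame.shape
        elif frame.shape != expected_shape:
            raise ValueError(
                "frame %d has shape %s, expected %s" % (index, frame.shape, expected_shape)
            )
        batch.append(frame)
        if len(batch) == batch_size:
            yield np.stack(batch)  # requires identical shapes
            batch = []
    if batch:
        yield np.stack(batch)  # last, possibly smaller, batch
```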
- To train a network, for example:
python train.py -i 'path/to/image/' -m 'path/to/masks' -v 'checkpointsavepath' -l 0.1 -d 0.99 -e 30 -b 10
Or, specifying only the image, mask, and checkpoint paths:
python train.py -i 'path/to/image/' -m 'path/to/masks' -v 'path/to/save/checkpoin'
Acknowledgements:
- The semantic segmentation code is modified from pytorch-unet.
- The img_distortion part of the database generation code is modified from [https://gist.github.com/erniejunior/601cdf56d2b424757de5].
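The image-distortion step credited above augments the training data by warping images. Below is a minimal sketch of elastic deformation in the spirit of that gist, extended so that the image and its multilayer mask receive the same random warp. It uses SciPy (`gaussian_filter`, `map_coordinates`). The function name and the default `alpha`/`sigma` values are assumptions, to be tuned for your image size. The mask is resampled with nearest-neighbor interpolation (`order=0`) so it stays strictly 0/1.

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_distort(image, mask, alpha=34.0, sigma=4.0, rng=None):
    """Apply one random elastic deformation to an image and its label mask.

    image: (H, W, 3) array; mask: (H, W, C) array of 0/1.
    alpha scales the displacement; sigma controls how smooth it is.
    """
    rng = np.random.default_rng() if rng is None else rng
    height, width = image.shape[:2]
    # Random displacement fields, smoothed so neighboring pixels move together.
    dx = gaussian_filter(rng.uniform(-1, 1, (height, width)), sigma) * alpha
    dy = gaussian_filter(rng.uniform(-1, 1, (height, width)), sigma) * alpha
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    coords = [np.reshape(rows + dy, -1), np.reshape(cols + dx, -1)]

    def warp(channel, order):
        # Sample the channel at the displaced coordinates, mirroring at borders.
        return map_coordinates(channel, coords, order=order, mode="reflect").reshape(
            height, width
        )

    # Bilinear for the image, nearest neighbor for the mask.
    new_image = np.stack([warp(image[:, :, c], 1) for c in range(image.shape[2])], axis=2)
    new_mask = np.stack([warp(mask[:, :, c], 0) for c in range(mask.shape[2])], axis=2)
    return new_image, new_mask
```

Calling it many times on the single labeled frame, each time with a fresh random warp, yields many distinct image/mask pairs.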
To do:
- Convert labels from labelme format to NumPy .npz format.
- Complete the data generation code.
- Change the loss function of the Pytorch-unet project from a log loss to a weighted loss function, along with other modifications.
- Make the whole workflow more automatic, using ./train_and_predict.py and ./data_gen/database_test.py.
- Make it work on Ubuntu.
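The change to a weighted loss function mentioned above is not specified further. One common choice is sketched here in NumPy: binary cross-entropy per label channel, scaled by class weights inversely proportional to each class's pixel share, so that small classes are not swamped by the background. The function names and the normalization (weights summing to C) are assumptions, not the project's code. In PyTorch, `torch.nn.BCEWithLogitsLoss` accepts `weight` and `pos_weight` tensors for a similar purpose; their shapes must broadcast against the (N, C, H, W) output, for example (C, 1, 1).

```python
import numpy as np


def inverse_frequency_weights(labels, eps=1e-6):
    """Per-class weights proportional to 1 / (fraction of pixels in that class).

    labels: (H, W, C) array of 0/1. Weights are normalized to sum to C.
    """
    frequency = labels.reshape(-1, labels.shape[-1]).mean(axis=0)
    weights = 1.0 / (frequency + eps)  # eps avoids division by zero for empty classes
    return weights * labels.shape[-1] / weights.sum()


def weighted_bce(probs, labels, weights, eps=1e-7):
    """Mean binary cross-entropy, each channel scaled by its class weight.

    probs, labels: (H, W, C) arrays; weights: (C,) array (broadcast over pixels).
    """
    probs = np.clip(probs, eps, 1 - eps)  # keep log() finite
    per_pixel = -(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))
    return float(np.mean(per_pixel * weights))
```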