Deep bilateral network based on the paper here. The network predicts local transformations from a low-resolution input and applies them adaptively to a high-resolution input using a bilateral slicing layer. It can be thought of as a network that learns a black-box image filter. The main advantages of this network are:
- Speed: the network is very shallow, and the bulk of the computation is done at low resolution.
- High-resolution output.
- Edge preservation: the network learns transformations rather than outputs directly, so it is less susceptible to artifacts.
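The core operation, bilateral slicing, can be sketched in a framework-agnostic way: a low-resolution grid of affine color transforms is sampled per full-resolution pixel via trilinear interpolation, with a learned guidance map choosing the depth (intensity) coordinate. The function below is an illustrative NumPy sketch with assumed shapes and names, not the repository's actual layer (which is the custom C++ op in `bilateral_slice_op/`):

```python
import numpy as np

def slice_bilateral_grid(grid, guide):
    """Slice per-pixel affine coefficients out of a low-res bilateral grid.

    grid:  (gh, gw, gd, 12) array holding a 3x4 affine color transform per
           grid cell, predicted at low resolution.
    guide: (H, W) array in [0, 1]; the learned guidance map that selects the
           depth coordinate for each full-resolution pixel.
    Returns an (H, W, 12) array of trilinearly interpolated coefficients.
    """
    gh, gw, gd, nc = grid.shape
    H, W = guide.shape
    # Continuous sampling coordinates into the grid for every output pixel.
    ys = (np.arange(H) + 0.5) / H * gh - 0.5
    xs = (np.arange(W) + 0.5) / W * gw - 0.5
    gy, gx = np.meshgrid(ys, xs, indexing="ij")
    gz = guide * gd - 0.5
    y0 = np.floor(gy).astype(int)
    x0 = np.floor(gx).astype(int)
    z0 = np.floor(gz).astype(int)
    fy, fx, fz = gy - y0, gx - x0, gz - z0
    coeffs = np.zeros((H, W, nc))
    # Accumulate the 8 trilinear corners, clamping indices at grid borders.
    for dy in (0, 1):
        for dx in (0, 1):
            for dz in (0, 1):
                w = ((fy if dy else 1 - fy)
                     * (fx if dx else 1 - fx)
                     * (fz if dz else 1 - fz))
                yi = np.clip(y0 + dy, 0, gh - 1)
                xi = np.clip(x0 + dx, 0, gw - 1)
                zi = np.clip(z0 + dz, 0, gd - 1)
                coeffs += w[..., None] * grid[yi, xi, zi]
    return coeffs
```

Each pixel's sliced 3x4 coefficients are then applied as an affine transform to that pixel's color. This is what makes the filter edge-preserving: neighboring pixels with different guide values pick up different transforms, so edges in the guide map survive in the output.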
Requirements:
- Python >= 3.6
- PyTorch >= 0.4
- tensorboardX
- OpenCV (for generating videos)
- scikit-video (for generating videos)
- CUDA (recommended, but not required)
The code is organized primarily into the following files and directories:
- `configuration.py`: Specifies the arguments used to train/evaluate/run. The global variables in this file are modified instead of passing command-line parameters to the scripts.
- `train.py`: Trains on a dataset. Acquires all parameters from `configuration.py`.
- `eval.py`: Evaluates on a dataset. Acquires all parameters from `configuration.py`.
- `generate_videos.py`: Runs a trained model on a directory of videos and writes the output. Acquires parameters from `configuration.py`; the input/output video directories can be specified directly within the script.
- `models/`: Contains the main `DeepBilateralNetCurves` model and its subclasses, which use slightly different modules.
- `datasets/`: Contains the input pipeline classes. `BaseDataset.py` defines a general-purpose dataset, and other subclasses can be created using it as a model. If you create a new dataset class, make sure to add the corresponding code needed to select it in `configuration.py`.
- `saved_runs/`: Contains the saved models and output logs from training.
- `bilateral_slice_op/`: Contains the C++ code for the custom bilateral slice layer as well as the code needed to build it.
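To illustrate the pattern `configuration.py` follows (the variable names below are hypothetical, not the repository's actual ones): scripts import module-level globals rather than parsing command-line flags, so you edit the file directly before each run.

```python
# configuration.py -- hypothetical excerpt; edit these globals directly
# instead of passing command-line flags to train.py / eval.py.
batch_size = 4
learning_rate = 1e-4
luma_bins = 8           # depth of the bilateral grid
pretrained_path = None  # set to a .pth file for eval / video generation

# A script then simply does:
#   import configuration as config
#   loader = DataLoader(dataset, batch_size=config.batch_size)
```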
This model uses a custom layer with a C++ implementation that must be built prior to training or running the model:

```
cd bilateral_slice_op
python setup.py install
```
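For reference, a `setup.py` for a PyTorch C++ extension of this kind typically uses the helpers from `torch.utils.cpp_extension`. This is a minimal sketch of the pattern, with an assumed source filename, not the repository's actual build script:

```python
# Hypothetical sketch of a PyTorch C++ extension build script; the actual
# file in bilateral_slice_op/ may differ (source names are assumed).
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

setup(
    name="bilateral_slice_op",
    # Compiles the C++ source against PyTorch's headers/ABI.
    ext_modules=[CppExtension("bilateral_slice_op", ["bilateral_slice.cpp"])],
    # BuildExtension injects the right compiler and linker flags for torch.
    cmdclass={"build_ext": BuildExtension},
)
```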
- Create a dataset. If using the `datasets/BaseDataset.py` loader, see `data/debug` for an example of how to structure the data. You can quickly run the model on this dataset to make sure that everything is working and that you have successfully built the bilateral slice layer.
- Specify the model and training parameters in `configuration.py` by modifying the global variables defined near the top of the file.
- Run `python train.py [run_name]`.
- To view the learning curve and all evaluation results in TensorBoard, run `tensorboard --logdir=saved_runs/[run_name]` and open the served port in a browser. If you are unfamiliar with TensorBoard, see its documentation.
- Specify the model parameters in `configuration.py` by modifying the global variables defined near the top of the file.
- Run `python eval.py [run_name]`.
- Specify the model parameters in `configuration.py` by modifying the global variables defined near the top of the file. Note that the variable `pretrained_path` must point to a `.pth` model file.
- Specify the `input_dir` and `output_dir` global variables at the top of the `generate_videos.py` file.
- Run `python generate_videos.py`.
- Pull the conversion repository (proprietary).
- Copy your trained `model.pth` file to `deep_bilateral_network/` in the conversion repository.
- Modify the model parameters as well as the input and output shapes inside the `load_model()` function of the `deep_bilateral_network/convert.py` file in the conversion repository.
- Follow the conversion repository's instructions for running the conversion script.
* Notes:
- The converted model does not currently include the final output following the bilateral slice layer, because we have yet to write a custom CoreML layer for it. Instead, both the coefficients and the guide map are returned.
- The linear combination of ReLUs applied when computing the guide map is currently hacked together using a series of other layers. Writing a custom CoreML layer for this would likely improve performance.
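The guide-map curve mentioned in the note above is a piecewise-linear function that can be written as a clipped sum of shifted ReLUs. A minimal NumPy sketch, with parameter names and the clipping range assumed for illustration:

```python
import numpy as np

def guide_curve(x, slopes, thresholds):
    """Piecewise-linear tone curve expressed as a sum of shifted ReLUs.

    x: intensities in [0, 1]; slopes/thresholds: length-K sequences of
    scalars (illustrative -- in the model these are learned parameters).
    """
    acc = np.zeros_like(x, dtype=float)
    for a, t in zip(slopes, thresholds):
        acc += a * np.maximum(x - t, 0.0)  # one shifted ReLU per segment
    return np.clip(acc, 0.0, 1.0)
```

Because each term only changes the slope past its threshold, a handful of ReLUs can represent an arbitrary monotonic or non-monotonic tone curve, which is why emulating it in CoreML requires stacking several elementary layers.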
TODO: Add documentation for this conversion.