swift-n-brutal / illuminant_estimation

Deep Specialized Network for Illuminant Estimation

Illuminant Estimation

This project implements the illuminant estimation method presented in the paper "Deep Specialized Network for Illuminant Estimation". The implementation is based on Python and TensorFlow.

Prerequisites

  • python=3.6
  • tensorflow=1.14
  • pyzmq=19.0.1 (optional)

Training from scratch

We use the Gehler-Shi dataset as an example. First, follow the instructions in the data directory to preprocess the data. Performance is evaluated with 3-fold cross-validation, so run the following commands to train three models, one per fold. Each command holds out the fold given by --gs-test-set for testing and runs on the GPU selected by CUDA_VISIBLE_DEVICES.

CUDA_VISIBLE_DEVICES=0 python solver.py --gs-has-loc --gs-test-set 0 &
CUDA_VISIBLE_DEVICES=1 python solver.py --gs-has-loc --gs-test-set 1 &
CUDA_VISIBLE_DEVICES=2 python solver.py --gs-has-loc --gs-test-set 2 &
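The three commands differ only in the GPU index and the --gs-test-set fold. If you prefer to script the launch, a minimal sketch could look like the following. It reuses only the flags shown above; the helper names fold_command and launch_folds are hypothetical and not part of the repository, and sys.executable stands in for the python binary.

```python
import os
import subprocess
import sys
from typing import List


def fold_command(fold: int, test_only: bool = False) -> List[str]:
    """Build the solver.py command line for one cross-validation fold."""
    cmd = [sys.executable, "solver.py", "--gs-has-loc", "--gs-test-set", str(fold)]
    if test_only:
        cmd.append("--test-only")
    return cmd


def launch_folds(num_folds: int = 3, test_only: bool = False) -> None:
    """Start one process per fold, each pinned to its own GPU, then wait for all."""
    procs = []
    for fold in range(num_folds):
        # Same effect as prefixing the shell command with CUDA_VISIBLE_DEVICES=<fold>.
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(fold))
        procs.append(subprocess.Popen(fold_command(fold, test_only), env=env))
    for proc in procs:
        proc.wait()
```

Calling launch_folds() mirrors the three training commands, and launch_folds(test_only=True) mirrors the test commands; running the shell commands directly works just as well.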

NOTE: ZeroMQ (pyzmq) is recommended for efficient training. Training each model takes roughly 12 hours on a single GeForce GTX TITAN X GPU.

With the default parameters, the trained model parameters are stored in the models directory with the following layout:

models
 |- gs568-0_bs128_lr0.02
 |   |- hypnet_4000000.npz
 |   |- selnet_4000000.npz
 |
 |- gs568-1_bs128_lr0.02
 |   |- hypnet_4000000.npz
 |   |- selnet_4000000.npz
 |
 |- gs568-2_bs128_lr0.02
     |- hypnet_4000000.npz
     |- selnet_4000000.npz
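Each .npz file is a standard NumPy archive, so the trained parameters can be inspected without TensorFlow. The hypnet_* and selnet_* files presumably hold the weights of the HypNet and SelNet branches described in the paper. The sketch below further assumes that a run name such as gs568-0_bs128_lr0.02 encodes the number of Gehler-Shi images (568), the test fold, the batch size and the learning rate; this reading is inferred from the names, not documented. The helpers parse_run_name and load_checkpoint are hypothetical, and the array keys inside each archive are whatever solver.py saved.

```python
import re
from pathlib import Path
from typing import Dict

import numpy as np

# Assumed naming pattern (inferred, not documented): gs<images>-<fold>_bs<batch>_lr<lr>
RUN_NAME = re.compile(r"^gs(?P<images>\d+)-(?P<fold>\d+)_bs(?P<batch>\d+)_lr(?P<lr>[\d.]+)$")


def parse_run_name(name: str) -> Dict[str, float]:
    """Split a run directory name such as 'gs568-0_bs128_lr0.02' into fields."""
    match = RUN_NAME.match(name)
    if match is None:
        raise ValueError(f"unrecognized run name: {name}")
    return {
        "images": int(match["images"]),
        "fold": int(match["fold"]),
        "batch_size": int(match["batch"]),
        "learning_rate": float(match["lr"]),
    }


def load_checkpoint(path: Path) -> Dict[str, np.ndarray]:
    """Load every array stored in an .npz checkpoint into a plain dict."""
    with np.load(path) as archive:
        return {key: archive[key] for key in archive.files}
```

For example, load_checkpoint(Path("models/gs568-0_bs128_lr0.02/hypnet_4000000.npz")) returns a dict from array names to NumPy arrays, whose keys and shapes can then be printed.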

Test

Then run the following commands to evaluate each of the three models on its held-out fold:

CUDA_VISIBLE_DEVICES=0 python solver.py --gs-has-loc --gs-test-set 0 --test-only &
CUDA_VISIBLE_DEVICES=1 python solver.py --gs-has-loc --gs-test-set 1 --test-only &
CUDA_VISIBLE_DEVICES=2 python solver.py --gs-has-loc --gs-test-set 2 --test-only &
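This README does not state which metric the test step reports. In the color constancy literature, illuminant estimates are usually scored by the angular error between the estimated and ground-truth illuminant RGB vectors, summarized over images by statistics such as the mean and median (trimean and best/worst 25% are also common). The sketch below is a generic NumPy implementation of that metric under this assumption; it is not code taken from solver.py.

```python
import numpy as np


def angular_error(est: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """Angle in degrees between estimated and ground-truth RGB illuminants.

    Both inputs have shape (N, 3); the result has shape (N,).
    """
    est = est / np.linalg.norm(est, axis=-1, keepdims=True)
    gt = gt / np.linalg.norm(gt, axis=-1, keepdims=True)
    # Clip to guard against cosines slightly outside [-1, 1] caused by rounding.
    cos = np.clip(np.sum(est * gt, axis=-1), -1.0, 1.0)
    return np.degrees(np.arccos(cos))


def summarize(errors: np.ndarray) -> dict:
    """Mean, median and worst-case angular error over a set of images."""
    return {
        "mean": float(np.mean(errors)),
        "median": float(np.median(errors)),
        "max": float(np.max(errors)),
    }
```

Because both vectors are normalized first, the error ignores overall brightness and depends only on the chromaticity of the estimate.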

Pre-trained models

Pre-trained models can be downloaded from the links below. Unzip the files into the models directory.

  • OneDrive: models trained for 3-fold cross-validation
  • OneDrive: models trained on all images
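The unzip step can also be scripted with Python's standard zipfile module. This is a minimal sketch: the helper extract_models is hypothetical, and because the names and internal layout of the downloaded archives are not given here, it is worth checking that the extracted files match the models layout described in the training section.

```python
import zipfile
from pathlib import Path
from typing import List


def extract_models(archive: Path, models_dir: Path = Path("models")) -> List[str]:
    """Unpack a downloaded model archive into the models directory.

    Returns the names of the entries stored in the archive.
    """
    models_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(models_dir)
        return zf.namelist()
```

The returned list of entry names makes it easy to confirm which run directories the archive contained.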

The test step stores the estimated illuminants for the local patches of each image in the preds directory. Finally, run the following command to aggregate them into a global prediction for each image:

python test_preds.py --weighted-median
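The --weighted-median option combines the patch-level estimates into one illuminant per image. The exact weighting used by test_preds.py is not described here, so the sketch below only illustrates the general idea: each RGB channel is aggregated independently with a (lower) weighted median over patches, using caller-supplied per-patch weights, and the result is normalized to unit length. The functions weighted_median and global_estimate, and the choice of weights, are assumptions for illustration.

```python
import numpy as np


def weighted_median(values: np.ndarray, weights: np.ndarray) -> float:
    """Smallest value whose cumulative weight reaches half of the total weight."""
    order = np.argsort(values)
    sorted_vals = values[order]
    cum = np.cumsum(weights[order])
    # First index where the cumulative weight is >= half of the total.
    idx = np.searchsorted(cum, 0.5 * cum[-1])
    return float(sorted_vals[idx])


def global_estimate(patch_rgb: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Combine (P, 3) patch illuminants into one unit-norm RGB estimate.

    Each channel is aggregated independently with a weighted median.
    """
    rgb = np.array([weighted_median(patch_rgb[:, c], weights) for c in range(3)])
    return rgb / np.linalg.norm(rgb)
```

With equal weights this reduces to a per-channel median, which is robust to a few patches with badly wrong estimates; unequal weights let more reliable patches dominate.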

Languages

Python 97.7%, MATLAB 2.3%