xinshi-chen / l2stop

Learning To Stop While Learning To Predict (ICML 2020)

If you find this library useful in your research, please consider citing:

@inproceedings{chen2020learning,
  title={Learning to stop while learning to predict},
  author={Chen, Xinshi and Dai, Hanjun and Li, Yu and Gao, Xin and Song, Le},
  booktitle={International Conference on Machine Learning},
  pages={1520--1530},
  year={2020},
  organization={PMLR}
}
@article{chen2020learning,
  title={Learning to Stop While Learning to Predict},
  author={Chen, Xinshi and Dai, Hanjun and Li, Yu and Gao, Xin and Song, Le},
  journal={arXiv preprint arXiv:2006.05082},
  year={2020}
}

Reproduce Experiments In Sec 5.1. Sparse Recovery

Install the module

Please navigate to the root of this repository, and run the following command to install the lista_stop module.

pip install -e .

Run traditional algorithms: ISTA and FISTA

Navigate to the /lista_stop/baselines folder and run the following commands to reproduce results of ISTA and FISTA, respectively.

sh run_ista.sh

sh run_fista.sh
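
For context, ISTA solves the sparse recovery problem min_x 0.5*||Ax - b||^2 + lam*||x||_1 by alternating a gradient step with soft-thresholding, and FISTA adds a momentum step to the same iteration. The following is a minimal pure-Python sketch of the textbook ISTA update. It is not the code run by run_ista.sh, and the names ista, soft_threshold, matvec, lam, and step are illustrative assumptions.

```python
def soft_threshold(v, thresh):
    """Shrink each entry of v toward zero by thresh (the prox of the L1 norm)."""
    return [max(abs(x) - thresh, 0.0) * (1.0 if x > 0 else -1.0) for x in v]


def matvec(A, x):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]


def ista(A, b, lam, step, num_iters=500):
    """Textbook ISTA for min_x 0.5*||Ax - b||^2 + lam*||x||_1.

    step should be at most 1/L, where L is the largest eigenvalue of A^T A.
    """
    n = len(A[0])
    At = [list(col) for col in zip(*A)]  # transpose of A
    x = [0.0] * n
    for _ in range(num_iters):
        # Gradient of the smooth term: A^T (A x - b)
        residual = [r - bi for r, bi in zip(matvec(A, x), b)]
        grad = matvec(At, residual)
        # Gradient step followed by soft-thresholding
        x = soft_threshold([xi - step * g for xi, g in zip(x, grad)], step * lam)
    return x
```

Every input runs for the same fixed number of iterations here, which is the behavior a learned stopping policy is meant to improve on.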

Run the baseline model: LISTA

Navigate to the /lista_stop/experiments folder and run the following command.

sh run_lista.sh

Run our method: LISTA-stop

The training process of LISTA-stop has two stages. For stage 1 training, navigate to the /lista_stop/experiments folder and run the following command.

sh run_lista_stop_stage1.sh

For stage 2 training, run the following command.

sh run_lista_stop_stage2.sh
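
To make the idea of a stopping policy concrete, one common way to turn per-step stop probabilities into a distribution over stop times is sketched below. Whether the lista_stop scripts use exactly this parameterization is an assumption here, and the function name stop_time_distribution is hypothetical.

```python
def stop_time_distribution(stop_probs):
    """Convert per-step stop probabilities pi_1..pi_{T-1} into q(1..T).

    q(t) = pi_t * prod_{s<t} (1 - pi_s) for t < T, and the final step T
    receives all remaining mass, so q sums to 1.
    """
    q = []
    continue_prob = 1.0  # probability of not having stopped yet
    for pi in stop_probs:
        q.append(continue_prob * pi)
        continue_prob *= 1.0 - pi
    q.append(continue_prob)  # forced stop at the last step
    return q
```

For example, stop_time_distribution([0.5, 0.5]) assigns probability 0.5 to stopping at step 1 and 0.25 each to steps 2 and 3.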

Reproduce Experiments In Sec 5.2. MAML

Please navigate to the maml_stop/ folder for details.

Reproduce Experiments In Sec 5.3. Image Denoising

Configure the environment

Please navigate to the section folder ./dncnn_stop, then run the following commands to configure the environment for the denoising experiments. Keep the environment activated for the rest of this section.

conda env create -f environment.yml
source activate dncnn_stop

Download the dataset

Please download the dataset and unzip it in this folder.

Run our method

Run the following commands to train and evaluate our method.

# pretrain the model
python -u train.py --model DnCNN --outf logs/dncnn_b_l20_all_train_n55 \
	--num_of_layers 20 --batchSize 256 --epoch 50

# fine-tune with tao set to 10
python -u train.py --model DnCNN_DS --outf logs/dncnn_b_ds_l20_all_train_tune_tao10 \
	--train_all True --batchSize 256 --lr 1e-4 --epoch 50 \
	--tao 10 --pretrain_path logs/dncnn_b_l20_all_train_n55/net.pth \
	--pretrain True

# policy training, this is in test phase
python train_stop_kl.py \
	--outf logs/dncnn_b_ds_l20_all_train_tune_tao10 --restart True -phase test

# joint training, this is in test phase
python train_stop_joint.py \
	--outf logs/dncnn_b_ds_l20_all_train_tune_tao10 --restart True -phase test

# Quantitative evaluation
for noise in 35 45 55 65 75; do
	python -u test.py --test_data Set68 --num_of_layers 20 \
		--logdir logs/dncnn_b_ds_l20_all_train_tune_tao10 --model DnCNN_DS \
		--test_noiseL ${noise}
done

# generate the denoised images
for noise in 45 65; do
	python -u test.py --test_data Set68 --num_of_layers 20 \
		--logdir logs/dncnn_b_ds_l20_all_train_tune_tao10 --model DnCNN_DS \
		--test_noiseL ${noise} --save_img True --img_folder ./out_imgs/dncnn_stop_${noise}
done
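
The quantitative evaluation above compares denoised outputs against clean images at each noise level. A standard metric for this comparison is PSNR. The sketch below is pure Python on flat pixel lists and assumes pixel values in [0, 1]; it is not taken from test.py, and the function name psnr and the data_range parameter are illustrative.

```python
import math


def psnr(clean, denoised, data_range=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel lists."""
    mse = sum((c - d) ** 2 for c, d in zip(clean, denoised)) / len(clean)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(data_range ** 2 / mse)
```

Higher values indicate a denoised image closer to the clean reference.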

Reproduce Experiments In Sec 5.4. Image Recognition

Please navigate to the section folder ./sdn_stop.

Download the dataset

Download TinyImageNet from https://tiny-imagenet.herokuapp.com/, place it under data/, and use create_val_folder() in data.py to generate the proper directory structure.
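
For readers unfamiliar with this step: the TinyImageNet validation split ships as a flat images/ folder plus a tab-separated val_annotations.txt whose first two columns are the file name and the class ID, while folder-based loaders expect one subfolder per class. The sketch below shows the kind of regrouping involved. It is not the repository's create_val_folder(); the function name regroup_val_images and the exact target layout are assumptions.

```python
import os
import shutil


def regroup_val_images(val_dir):
    """Move val/images/<file> into val/images/<class_id>/<file>."""
    images_dir = os.path.join(val_dir, "images")
    annotations = os.path.join(val_dir, "val_annotations.txt")
    with open(annotations) as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 2:
                continue  # skip blank or malformed lines
            filename, class_id = fields[0], fields[1]
            src = os.path.join(images_dir, filename)
            if not os.path.isfile(src):
                continue  # already moved or missing
            class_dir = os.path.join(images_dir, class_id)
            os.makedirs(class_dir, exist_ok=True)
            shutil.move(src, os.path.join(class_dir, filename))
```

Because files that are already moved are skipped, running the function a second time leaves the directory unchanged.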

Run the experiment

# Train the classifiers
python train_networks.py

# Check the performance of sdn and l2stop. The policy network does not work at the current stage for this task.
python train_stop_kl.py
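
For comparison, the sdn baseline follows Shallow-Deep Networks, where internal classifiers are attached to intermediate layers and inference can exit at the first classifier that is confident enough. The sketch below illustrates such a confidence-threshold rule on precomputed class probabilities. It is an illustration only, not code from this repository, and the name early_exit and its arguments are hypothetical.

```python
def early_exit(exit_probs, threshold):
    """Return (exit_index, predicted_class) for the first confident exit.

    exit_probs is a list, ordered from the shallowest to the deepest exit,
    of class-probability lists. The last exit is used if none is confident.
    """
    for idx, probs in enumerate(exit_probs):
        confidence = max(probs)
        if confidence >= threshold or idx == len(exit_probs) - 1:
            return idx, probs.index(confidence)
```

A learned stopping policy replaces the fixed threshold with a decision that is trained together with the predictor.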

Credit and note

This part is built on http://shallowdeep.network. The main change is to the loss function in model_funcs.py.

License: MIT License