tianyic / only_train_once

OTOv1-v3, NeurIPS, ICLR, TMLR, DNN Training, Compression, Structured Pruning, Erasing Operators, CNN, Diffusion, LLM

Home Page: https://openreview.net/pdf?id=7ynoX1ojPMt


How to set the epsilon param of dhspg?

gaoxueyi0532 opened this issue · comments

In the tutorials, epsilon is set to 0.95, but the paper's experiments and theory analysis recommend a range of [0.0, 0.05], which is confusing!
optimizer = oto.dhspg(
    variant='sgd',
    lr=0.1,
    target_group_sparsity=0.7,
    weight_decay=1e-4,
    start_pruning_steps=50 * len(trainloader),  # start pruning after 50 epochs
    epsilon=0.95)

Below is my code:
opt = oto.dhspg(
    variant='sgd',
    lr=0.01,
    target_group_sparsity=0.3,
    weight_decay=1e-4,
    start_pruning_steps=100 * len(train_loader),  # start pruning after 100 epochs
    epsilon=0.02)

Which one is reasonable? Or are both reasonable?

This is a good question. For the theorem, we only require epsilon in [0, 1). The value difference is because the implementation in the 2020 HSPG arXiv paper has a slight discrepancy. Please use the up-to-date version, i.e., this repo. Meanwhile, we typically use epsilon of 0.9 or 0.95 for all our experiments.
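As a quick illustration (a sketch only; the assert is just a sanity check on the theoretical range, not part of the library API, and the other arguments mirror the tutorial snippet above):

epsilon = 0.95  # theory only requires epsilon in [0, 1); we use 0.9 or 0.95 in experiments
assert 0.0 <= epsilon < 1.0, "epsilon must lie in [0, 1)"
optimizer = oto.dhspg(variant='sgd', lr=0.1, target_group_sparsity=0.7, epsilon=epsilon)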

I suspect you raised this question because, when applying DHSPG to your model, the group sparsity does not develop as expected. If so, to mitigate the issue, please keep the points below in mind.

  • Start pruning at an early stage. You started pruning at 100 epochs, and I am not sure how long your total training lasts. In our practice, we start pruning right after a few warm-up epochs.
  • Adjust other hyper-parameters if needed. Your initial learning rate is 1e-2. In hyperparameters, we provide default settings for group sparsity exploration across different optimizers. Please increase lmbda, lmbda_amplify, and hat_lmbda_coeff to roughly 10x larger if group sparsity does not develop well (a combined sketch follows the snippet below).

All such hyperparameters can be set up in the optimizer call:

optimizer = oto.dhspg(
    # ... other arguments as before ...
    lmbda=...,
    lmbda_amplify=...,
    hat_lmbda_coeff=...,
)
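Putting the two points together, a rough end-to-end sketch for your setup could look like the following. The concrete values for lmbda, lmbda_amplify, and hat_lmbda_coeff are illustrative placeholders only; please start from the documented defaults for your optimizer variant in hyperparameters and scale them roughly 10x if group sparsity still stalls:

warmup_epochs = 5  # placeholder: start pruning after only a few warm-up epochs, not 100

opt = oto.dhspg(
    variant='sgd',
    lr=0.01,
    target_group_sparsity=0.3,
    weight_decay=1e-4,
    start_pruning_steps=warmup_epochs * len(train_loader),
    epsilon=0.95,          # the value we use across our experiments
    lmbda=1e-2,            # placeholder: roughly 10x your variant's default
    lmbda_amplify=2.0,     # placeholder: roughly 10x your variant's default
    hat_lmbda_coeff=10.0)  # placeholder: roughly 10x your variant's default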

Hope the above helps.