hustzxd / EfficientPyTorch

A PyTorch framework for efficient pruning and quantization targeting specialized accelerators.

Optimizer BUG

hustzxd opened this issue.

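Please define the optimizer after the Conv2d layers have been replaced. A PyTorch optimizer stores references to the parameters it is given at construction time. If it is created before the replacement, it keeps pointing at the parameters of the original Conv2d layers, so the parameters of the new Conv2dSQ layers are never updated during training.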

```python
# Prune and quantize the model by replacing every Conv2d with Conv2dSQ
wrapper.replace_conv_recursively(model, 'Conv2dSQ', nbits_a=args.qa, nbits_w=args.qw,
                                 sparsity=args.sparsity,
                                 total_iter=args.batch_num * args.epochs,
                                 INS=args.INS, beta=args.beta)
# Define the loss function (criterion) and the optimizer only AFTER the model
# has been processed, so the optimizer sees the new layers' parameters
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
optimizer = get_optimizer(model, args)
```
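The effect of the ordering can be reproduced with a minimal, self-contained sketch. `ToyQuantConv2d` and `replace_conv_recursively` below are hypothetical stand-ins, not the repository's `Conv2dSQ` or `wrapper` API; the sketch assumes the replacement builds new modules that own new `Parameter` objects, which is the situation in which the order matters. It builds a plain `torch.optim.SGD` optimizer either before or after the swap, runs one training step, and reports whether the optimizer tracks every current parameter and whether any parameter actually changed.

```python
import torch
import torch.nn as nn


class ToyQuantConv2d(nn.Conv2d):
    """Hypothetical stand-in for a replacement layer such as Conv2dSQ."""


def replace_conv_recursively(module: nn.Module) -> None:
    """Hypothetical helper: swap every nn.Conv2d for a newly built ToyQuantConv2d."""
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d) and not isinstance(child, ToyQuantConv2d):
            new = ToyQuantConv2d(
                child.in_channels, child.out_channels, child.kernel_size,
                stride=child.stride, padding=child.padding,
                dilation=child.dilation, groups=child.groups,
                bias=child.bias is not None,
            )
            # The values are copied, but `new` owns brand-new Parameter objects.
            new.load_state_dict(child.state_dict())
            setattr(module, name, new)
        else:
            replace_conv_recursively(child)


def train_one_step(optimizer_first: bool) -> tuple[bool, bool]:
    """Return (optimizer tracks every model parameter, any parameter changed)."""
    torch.manual_seed(0)
    model = nn.Sequential(
        nn.Conv2d(3, 4, 3, padding=1),
        nn.Sequential(nn.Conv2d(4, 4, 3, padding=1)),
    )
    if optimizer_first:
        # Bug: the optimizer captures the original Conv2d parameters.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        replace_conv_recursively(model)
    else:
        # Fix: replace the layers first, then build the optimizer.
        replace_conv_recursively(model)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
    tracks_all = all(id(p) in tracked for p in model.parameters())

    before = [p.detach().clone() for p in model.parameters()]
    loss = model(torch.randn(2, 3, 8, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()   # gradients land on the NEW parameters
    optimizer.step()  # SGD skips parameters whose .grad is None (the stale ones)
    changed = any(not torch.equal(b, p) for b, p in zip(before, model.parameters()))
    return tracks_all, changed


print("optimizer first:", train_one_step(optimizer_first=True))
print("optimizer last: ", train_one_step(optimizer_first=False))
```

With the optimizer built first, both flags come back `False`: training silently runs without error but the replaced layers never learn. Building it after the replacement makes both flags `True`.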