intel / intel-extension-for-pytorch

A Python package for extending the official PyTorch that can easily obtain performance on Intel platforms

How to use ipex.optimize for two models (model distillation) and one optimizer

ferreirafabio opened this issue

Describe the issue

Hi,
I'm wondering how to use ipex.optimize(...) when I have two models, for example a teacher and a student in model distillation, but only one optimizer. Would calls like the following work? Note that the teacher and student are both in train mode, but the teacher's parameters have requires_grad=False, and the optimizer only operates on the student's parameters.

teacher, optimizer = ipex.optimize(model=teacher, dtype=torch.float32, optimizer=optimizer)
student, optimizer = ipex.optimize(model=student, dtype=torch.float32, optimizer=optimizer)

I'm not aware of the intricacies this way of calling it may entail, and would like confirmation that this is okay to do. Thank you.
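The setup described in the question can be checked in plain PyTorch before any IPEX call. This is a minimal sketch, assuming toy `nn.Sequential` models in place of the real teacher and student and a hypothetical helper `optimizer_param_ids`: it keeps both models in train mode, freezes the teacher, and confirms that the single optimizer only holds the student's tensors. Whether passing that same optimizer to `ipex.optimize` twice is safe is the open question of this issue and is not tested here.

```python
import torch
import torch.nn as nn

# Hypothetical toy models standing in for the real teacher and student.
teacher = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU())
student = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU())

# Both models stay in train mode, as described in the question.
teacher.train()
student.train()

# Freeze the teacher: autograd will not compute gradients for its parameters.
for p in teacher.parameters():
    p.requires_grad_(False)

# A single optimizer that only covers the student's parameters.
optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)


def optimizer_param_ids(opt):
    """Hypothetical helper: ids of every tensor the optimizer will update."""
    return {id(p) for group in opt.param_groups for p in group["params"]}
```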

@ferreirafabio Hey Fabio, assuming the teacher is a large pretrained model, the optimizer is only needed for the student model. The following snippet shows how to use IPEX in this setup:

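import torch
import torch.nn.functional as F
import intel_extension_for_pytorch as ipex

teacher.eval()
student.train()
optimizer = ...  # an optimizer over student.parameters() only

teacher = ipex.optimize(model=teacher, dtype=torch.float32)
student, optimizer = ipex.optimize(model=student, dtype=torch.float32, optimizer=optimizer)

for train_data in train_loader:  # hypothetical DataLoader: iterate over train data
    with torch.no_grad():
        teacher_probs = F.softmax(teacher(train_data), dim=-1)

    student_probs = F.softmax(student(train_data), dim=-1)
    loss = ...  # compute combined loss from teacher_probs and student_probs
    optimizer.zero_grad()
    loss.backward()   # backprop
    optimizer.step()  # optimizer update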
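The elided parts of the loop (combined loss, backprop, optimizer update) can be sketched in plain PyTorch, leaving IPEX out so it runs on a stock install. This is a minimal sketch, not the maintainer's code: the toy models, the hypothetical `distillation_step` helper, the temperature of 2.0 and the KL-divergence loss are all assumptions standing in for "compute combined loss". The student side uses `log_softmax` because `F.kl_div` expects log-probabilities as its input; the teacher runs under `torch.no_grad()`, so only the student's weights move.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy models standing in for the real teacher and student.
teacher = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
student = nn.Sequential(nn.Linear(8, 4))

teacher.eval()   # the teacher only runs inference
student.train()
for p in teacher.parameters():
    p.requires_grad_(False)

# The optimizer only sees the student's parameters.
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)


def distillation_step(x, temperature=2.0):
    """One training step with an assumed KL-divergence distillation loss."""
    with torch.no_grad():
        teacher_probs = F.softmax(teacher(x) / temperature, dim=-1)
    student_log_probs = F.log_softmax(student(x) / temperature, dim=-1)
    loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    optimizer.zero_grad()
    loss.backward()   # gradients flow into the student only
    optimizer.step()  # update the student's weights
    return loss.item()


teacher_before = copy.deepcopy(teacher.state_dict())
student_before = copy.deepcopy(student.state_dict())

x = torch.randn(32, 8)
loss = distillation_step(x)
```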

Other things to consider:

  1. For the teacher model, you could also apply torch.jit.trace and torch.jit.freeze for additional inference performance gains; refer to the IPEX inference code samples.
  2. Using the torch.bfloat16 dtype on the latest Intel Xeon CPUs and Intel GPUs provides better performance.
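Tip 1 can be sketched on stock PyTorch with a hypothetical toy teacher. `torch.jit.trace` records the forward pass for an example input, and `torch.jit.freeze` inlines the weights and attributes as constants. Note that `torch.jit.freeze` only accepts modules in eval mode and raises a `RuntimeError` otherwise, so this tip applies to a teacher that can run under `.eval()`. The IPEX inference samples appear to apply these calls after `ipex.optimize`; that call is left out here so the sketch runs without IPEX installed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy teacher; the real one would be the pretrained network.
teacher = nn.Sequential(
    nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 4)
)
teacher.eval()  # torch.jit.freeze only accepts modules in eval mode

example = torch.randn(4, 8)
with torch.no_grad():
    traced = torch.jit.trace(teacher, example)  # record the forward graph
    frozen = torch.jit.freeze(traced)           # inline weights as constants


def run_frozen(x):
    """Run inference with the frozen teacher (hypothetical helper)."""
    with torch.no_grad():
        return frozen(x)
```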

@vishnumadhu365 Thank you for your reply and the code example. I tried it, but I get the following error:

AssertionError: The optimizer should be given for training mode

This happens because the teacher is still in train mode: I do not call .eval() on it. AFAIK, I cannot put it in eval mode, because I still want the batch norm statistics from training mode. The precise application is training DINO; see here for the exact usage of the teacher and student:

https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/main_dino.py#L210

What can be done in such a case?
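For context, DINO does not train its teacher with the optimizer: the teacher's weights are an exponential moving average (EMA) of the student's, and the teacher's `requires_grad` flags are off. A minimal plain-PyTorch sketch of that setup, assuming a toy batch-norm model built by a hypothetical `make_model`, a hypothetical `ema_update` helper, and an assumed momentum of 0.996. It shows that a teacher left in train mode still normalizes with batch statistics and updates its batch-norm running statistics inside `torch.no_grad()`. It illustrates the setup only and does not work around the `ipex.optimize` assertion.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def make_model():
    """Hypothetical toy backbone with batch norm."""
    return nn.Sequential(
        nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 4)
    )


student = make_model()
teacher = make_model()
teacher.load_state_dict(student.state_dict())  # teacher starts as a copy
for p in teacher.parameters():
    p.requires_grad_(False)  # the optimizer never touches the teacher

student.train()
teacher.train()  # left in train mode: batch norm uses batch statistics


@torch.no_grad()
def ema_update(teacher, student, momentum=0.996):
    """Move each teacher parameter towards the student: t = m * t + (1 - m) * s."""
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(momentum).add_(p_s.detach(), alpha=1.0 - momentum)


bn = teacher[1]
running_mean_before = bn.running_mean.clone()
x = torch.randn(32, 8)
with torch.no_grad():
    # No autograd graph is built, but running stats are still updated.
    teacher_out = teacher(x)
```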