- Clone the repo
- Install using pip:

```bash
pip install -e .
```
See https://github.com/samirsalman/distillai/blob/main/examples/example.py for a complete example.
```python
import torch
from torch import nn

from distill_ai.losses.distillation_loss import DistillationLoss
from distill_ai.trainers.distillation_trainer import KnowledgeDistillationTrainer

# init student and teacher models
student = Student()
teacher = Teacher()

# define the student's hard-label loss
student_target_loss = nn.CrossEntropyLoss()

# define distillation loss
distillation_loss = DistillationLoss(alpha=0.25, temperature=1.0)

# define optimizer (the class, not an instance)
optimizer = torch.optim.Adam

# define trainer
trainer = KnowledgeDistillationTrainer(
    max_epochs=10,
    # pytorch lightning kwargs
)

# train
trainer.fit(
    # torch data loaders
    train_dataloader=train_dataloader,
    val_dataloader=test_dataloader,
    # models
    teacher_model=teacher,
    student_model=student,
    # loss functions
    student_target_loss=student_target_loss,
    distillation_loss=distillation_loss,
    # optimizer
    optimizer=optimizer,
)
```
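To give an intuition for what `DistillationLoss(alpha, temperature)` computes, here is a minimal sketch of a standard Hinton-style knowledge-distillation loss. This is an illustration only, not the library's actual implementation: it assumes `alpha` weights the hard-label term against the soft (teacher) term, and that `temperature` softens both logit distributions before taking the KL divergence.

```python
import torch
import torch.nn.functional as F

def soft_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    # Soften both distributions with the temperature, then take the
    # KL divergence; the T^2 factor keeps gradient magnitudes
    # comparable across temperatures.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature ** 2

def total_loss(student_logits, teacher_logits, targets, alpha=0.25, temperature=1.0):
    # alpha-weighted mix of the hard-label loss and the soft teacher loss
    hard = F.cross_entropy(student_logits, targets)
    soft = soft_distillation_loss(student_logits, teacher_logits, temperature)
    return alpha * hard + (1 - alpha) * soft
```

When the student's logits match the teacher's exactly, the soft term vanishes and only the hard-label cross-entropy remains.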
Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are greatly appreciated.
If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". Don't forget to give the project a star! Thanks again!
- Fork the Project
- Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
- Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
- Push to the Branch (`git push origin feature/AmazingFeature`)
- Open a Pull Request