MuHDi

Official PyTorch implementation of "Multi-Head Distillation for Continual Unsupervised Domain Adaptation in Semantic Segmentation"

Multi-Head Distillation for Continual Unsupervised Domain Adaptation in Semantic Segmentation

Paper

Multi-Head Distillation for Continual Unsupervised Domain Adaptation in Semantic Segmentation
Antoine Saporta, Arthur Douillard, Tuan-Hung Vu, Patrick Pérez, Matthieu Cord
IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 2022 Workshop on Continual Learning

If you find this code useful for your research, please cite our paper:

@inproceedings{saporta2022muhdi,
  title={Multi-Head Distillation for Continual Unsupervised Domain Adaptation in Semantic Segmentation},
  author={Saporta, Antoine and Douillard, Arthur and Vu, Tuan-Hung and P{\'e}rez, Patrick and Cord, Matthieu},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshop},
  year={2022}
}

Abstract

Unsupervised Domain Adaptation (UDA) is a transfer learning task which aims at training on an unlabeled target domain by leveraging a labeled source domain. Beyond the traditional scope of UDA with a single source domain and a single target domain, real-world perception systems face a variety of scenarios to handle, from varying lighting conditions to many cities around the world. In this context, UDAs with several domains increase the challenges with the addition of distribution shifts within the different target domains. This work focuses on a novel framework for learning UDA, continuous UDA, in which models operate on multiple target domains discovered sequentially, without access to previous target domains. We propose MuHDi, for Multi-Head Distillation, a method that solves the catastrophic forgetting problem, inherent in continual learning tasks. MuHDi performs distillation at multiple levels from the previous model as well as an auxiliary target-specialist segmentation head. We report both extensive ablation and experiments on challenging multi-target UDA semantic segmentation benchmarks to validate the proposed learning scheme and architecture.
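
For intuition, here is a minimal sketch of the kind of multi-level distillation described above. This is illustrative only, not the repository's actual API: the function names, the temperature T, and the weight lambda_feat are assumptions. The new model's pixel-wise predictions are pulled towards both the frozen previous model and the auxiliary target-specialist segmentation head, with an additional feature-level term.

import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, T=2.0):
    # Pixel-wise soft-target distillation: KL divergence between
    # temperature-scaled class distributions (logits: N x C x H x W).
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)

def multi_head_distillation(new_logits, prev_logits, aux_logits,
                            new_feats, prev_feats, lambda_feat=0.1):
    # (i) output-level distillation from the frozen previous model,
    # (ii) distillation from the auxiliary target-specialist head,
    # (iii) a feature-level term on intermediate representations.
    loss = kd_loss(new_logits, prev_logits.detach())
    loss = loss + kd_loss(new_logits, aux_logits.detach())
    loss = loss + lambda_feat * F.mse_loss(new_feats, prev_feats.detach())
    return loss

The exact losses and weights used in the paper are defined by the training configurations referenced below.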

Preparation

Pre-requisites

  • Python 3.7
  • PyTorch >= 0.4.1
  • CUDA 9.0 or higher

Installation

  1. Clone the repo:
$ git clone https://github.com/valeoai/MuHDi
$ cd MuHDi
  2. Install OpenCV if you don't already have it:
$ conda install -c menpo opencv
  3. Install this repository and the dependencies using pip:
$ pip install -e <root_dir>

With this, you can edit the MuHDi code on the fly and import MuHDi functions and classes in other projects as well.

  4. Optional: to uninstall this package, run:
$ pip uninstall MuHDi

Datasets

By default, the datasets are put in <root_dir>/data. We use symlinks to hook the MuHDi codebase to the datasets. An alternative option is to explicitly specify the parameters DATA_DIRECTORY_SOURCE and DATA_DIRECTORY_TARGET in YML configuration files.
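
For example, assuming Cityscapes has been downloaded to /path/to/cityscapes (an illustrative path, adapt it to your setup), the symlink can be created with:

$ mkdir -p <root_dir>/data
$ ln -s /path/to/cityscapes <root_dir>/data/cityscapes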

  • GTA5: Please follow the instructions here to download images and semantic segmentation annotations. The GTA5 dataset directory should have this basic structure:
<root_dir>/data/GTA5/                               % GTA dataset root
<root_dir>/data/GTA5/images/                        % GTA images
<root_dir>/data/GTA5/labels/                        % Semantic segmentation labels
...
  • Cityscapes: Please follow the instructions in Cityscapes to download the images and ground-truths. The Cityscapes dataset directory should have this basic structure:
<root_dir>/data/cityscapes/                         % Cityscapes dataset root
<root_dir>/data/cityscapes/leftImg8bit              % Cityscapes images
<root_dir>/data/cityscapes/leftImg8bit/train
<root_dir>/data/cityscapes/leftImg8bit/val
<root_dir>/data/cityscapes/gtFine                   % Semantic segmentation labels
<root_dir>/data/cityscapes/gtFine/train
<root_dir>/data/cityscapes/gtFine/val
...
  • Mapillary: Please follow the instructions in Mapillary Vistas to download the images and validation ground-truths. The Mapillary Vistas dataset directory should have this basic structure:
<root_dir>/data/mapillary/                          % Mapillary dataset root
<root_dir>/data/mapillary/train                     % Mapillary train set
<root_dir>/data/mapillary/train/images
<root_dir>/data/mapillary/validation                % Mapillary validation set
<root_dir>/data/mapillary/validation/images
<root_dir>/data/mapillary/validation/labels
...
  • IDD: Please follow the instructions in IDD to download the images and validation ground-truths. The IDD Segmentation dataset directory should have this basic structure:
<root_dir>/data/IDD/                         % IDD dataset root
<root_dir>/data/IDD/leftImg8bit              % IDD images
<root_dir>/data/IDD/leftImg8bit/train
<root_dir>/data/IDD/leftImg8bit/val
<root_dir>/data/IDD/gtFine                   % Semantic segmentation labels
<root_dir>/data/IDD/gtFine/val
...

Pre-trained models

Pre-trained models can be downloaded here and put in <root_dir>/pretrained_models.

Running the code

For evaluation, execute:

$ cd <root_dir>/muhdi/scripts
$ python test.py --cfg ./configs/gta2cityscapes_advent_pretrained_test.yml
$ python test.py --cfg ./configs/gta2cityscapes2idd_baseline_pretrained_test.yml
$ python test.py --cfg ./configs/gta2cityscapes2idd_muhdi_pretrained_test.yml
$ python test.py --cfg ./configs/gta2cityscapes2idd2mapillary_baseline_pretrained_test.yml
$ python test.py --cfg ./configs/gta2cityscapes2idd2mapillary_muhdi_pretrained_test.yml

Training

For the experiments in the paper, we used PyTorch 1.3.1 and CUDA 10.0. To ensure reproducibility, the random seed is fixed in the code. Still, you may need to train a few times to reach comparable performance.

By default, logs and snapshots are stored in <root_dir>/experiments with this structure:

<root_dir>/experiments/logs
<root_dir>/experiments/snapshots

To train the GTA5 -> Cityscapes AdvEnt model from the ImageNet pretrained ResNet:

$ cd <root_dir>/muhdi/scripts
$ python train.py --cfg ./configs/gta2cityscapes_advent.yml

To train the GTA5 -> Cityscapes -> IDD baseline model from the GTA5 -> Cityscapes pretrained AdvEnt:

$ cd <root_dir>/muhdi/scripts
$ python train.py --cfg ./configs/gta2cityscapes2idd_baseline.yml

To train the GTA5 -> Cityscapes -> IDD MuHDi model from the GTA5 -> Cityscapes pretrained AdvEnt:

$ cd <root_dir>/muhdi/scripts
$ python train.py --cfg ./configs/gta2cityscapes2idd_muhdi.yml

To train the GTA5 -> Cityscapes -> IDD -> Mapillary baseline model from the GTA5 -> Cityscapes pretrained baseline model:

$ cd <root_dir>/muhdi/scripts
$ python train.py --cfg ./configs/gta2cityscapes2idd2mapillary_baseline.yml

To train the GTA5 -> Cityscapes -> IDD -> Mapillary MuHDi model from the GTA5 -> Cityscapes pretrained MuHDi model:

$ cd <root_dir>/muhdi/scripts
$ python train.py --cfg ./configs/gta2cityscapes2idd2mapillary_muhdi.yml

Testing

To test the GTA5 -> Cityscapes AdvEnt model on Cityscapes:

$ cd <root_dir>/muhdi/scripts
$ python test.py --cfg ./configs/gta2cityscapes_advent.yml

To test the GTA5 -> Cityscapes -> IDD baseline model on IDD:

$ cd <root_dir>/muhdi/scripts
$ python test.py --cfg ./configs/gta2cityscapes2idd_baseline.yml

To test the GTA5 -> Cityscapes -> IDD MuHDi model on IDD:

$ cd <root_dir>/muhdi/scripts
$ python test.py --cfg ./configs/gta2cityscapes2idd_muhdi.yml

To test the GTA5 -> Cityscapes -> IDD -> Mapillary baseline model on Mapillary:

$ cd <root_dir>/muhdi/scripts
$ python test.py --cfg ./configs/gta2cityscapes2idd2mapillary_baseline.yml

To test the GTA5 -> Cityscapes -> IDD -> Mapillary MuHDi model on Mapillary:

$ cd <root_dir>/muhdi/scripts
$ python test.py --cfg ./configs/gta2cityscapes2idd2mapillary_muhdi.yml

Acknowledgements

This codebase borrows heavily from ADVENT and MTAF, as well as PLOP.

License

MuHDi is released under the Apache 2.0 license.
