reai000 / thermal-uda-attention

Unsupervised RGB-to-Thermal Domain Adaptation via Multi-Domain Attention Network

thermal-uda-cls

Install

$ git clone https://github.com/ganlumomo/thermal-uda-cls.git
$ cd thermal-uda-cls
$ conda env create -f environment.yml
$ conda activate thermal-uda-cls

Dataset Preparation

Download

Process

$ python utils/prepare_dataset_mscoco_flir.py
$ python utils/prepare_dataset_mscoco_m3fd.py

Running

Training for MS-COCO to FLIR

(thermal-uda-cls) $ python core/main.py \
 --tgt_cat flir --n_classes 3 \
 --batch_size 32 --epochs 15 \
 --device cuda:0 --logdir outputs/flir

Training for MS-COCO to M3FD

(thermal-uda-cls) $ python core/main.py \
 --tgt_cat m3fd --n_classes 6 \
 --batch_size 32 --epochs 30 \
 --device cuda:0 --logdir outputs/m3fd
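
The two training runs differ only in the target dataset, the class count and the epoch budget. The sketch below is a hypothetical helper, not part of the repo, so `build_train_cmd` and `TARGETS` are illustrative names. It collects those documented settings in one place and assembles the matching `core/main.py` argument list, which may be handy when scripting both runs.

```python
# Hypothetical helper (not part of the repo): rebuilds the core/main.py
# command lines from the per-target settings documented in this README.
import shlex

# Values copied from the two training commands in this README.
TARGETS = {
    "flir": {"n_classes": 3, "epochs": 15},
    "m3fd": {"n_classes": 6, "epochs": 30},
}


def build_train_cmd(tgt_cat, batch_size=32, device="cuda:0",
                    self_train=False, wandb=False):
    """Return the argument list for training MS-COCO -> tgt_cat."""
    cfg = TARGETS[tgt_cat]
    cmd = [
        "python", "core/main.py",
        "--tgt_cat", tgt_cat,
        "--n_classes", str(cfg["n_classes"]),
        "--batch_size", str(batch_size),
        "--epochs", str(cfg["epochs"]),
        "--device", device,
        "--logdir", f"outputs/{tgt_cat}",
    ]
    # Optional switches from the README's list of optional training flags.
    if self_train:
        cmd.append("--self_train")
    if wandb:
        cmd.append("--wandb")
    return cmd


if __name__ == "__main__":
    # Print a shell-ready command line for each documented target.
    for tgt in TARGETS:
        print(shlex.join(build_train_cmd(tgt)))
```

Passing the returned list to `subprocess.run(cmd, check=True)` from the repository root, inside the activated `thermal-uda-cls` environment, would launch the corresponding run.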

Optional flags:

  • --self_train: enable self-training using pseudo-labels
  • --wandb: enable Weights & Biases (wandb) logging
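
For reference, the training flags used in this README could be declared with Python's `argparse` roughly as follows. This is a hypothetical reconstruction inferred only from the commands shown here, not the actual contents of `core/main.py`. In particular, the defaults, the `choices` restriction on `--tgt_cat`, and the assumption that `--self_train` and `--wandb` are value-less boolean switches are guesses.

```python
# Hypothetical reconstruction of the core/main.py command-line interface,
# inferred only from the flags used in this README. Defaults are assumptions.
import argparse


def make_parser():
    p = argparse.ArgumentParser(
        description="MS-COCO -> thermal domain adaptation training (sketch)")
    p.add_argument("--tgt_cat", choices=["flir", "m3fd"], required=True,
                   help="target thermal dataset")
    p.add_argument("--n_classes", type=int, required=True,
                   help="number of classes (3 for FLIR, 6 for M3FD)")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--epochs", type=int, default=15)
    p.add_argument("--device", default="cuda:0")
    p.add_argument("--logdir", default="outputs")
    # Assumed to be on/off switches, since the README passes no value.
    p.add_argument("--self_train", action="store_true",
                   help="self-training with pseudo-labels")
    p.add_argument("--wandb", action="store_true",
                   help="enable Weights & Biases logging")
    return p


if __name__ == "__main__":
    # Parse the M3FD training command from this README, plus --self_train.
    args = make_parser().parse_args(
        "--tgt_cat m3fd --n_classes 6 --batch_size 32 --epochs 30 "
        "--device cuda:0 --logdir outputs/m3fd --self_train".split())
    print(args.tgt_cat, args.n_classes, args.epochs, args.self_train)
```

Declaring the boolean flags with `action="store_true"` is what lets them be passed bare, as in `--self_train`, without an explicit value.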

Test

(thermal-uda-cls) $ python core/test.py \
 --tgt_cat m3fd --n_classes 6 \
 --trained outputs/m3fd/best_model.pt \
 --device cuda:0 --logdir outputs/m3fd

Optional flags:

  • --d_trained outputs/m3fd/best_model_d.pt: checkpoint used for pseudo-label generation
  • --tsne: enable t-SNE visualization
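
Pseudo-label generation for self-training commonly keeps only the unlabeled target samples whose predicted class probability exceeds a confidence threshold, then reuses those predictions as labels in a later training round. The plain-Python sketch below illustrates that generic idea only. The function names and the 0.9 threshold are hypothetical, and this README does not state the exact selection rule used by `core/test.py` or how the `--d_trained` checkpoint enters into it.

```python
# Generic confidence-threshold pseudo-labelling (illustrative sketch only;
# not necessarily the selection criterion used by this repo).
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_pseudo_labels(batch_logits, threshold=0.9):
    """Return (index, label, confidence) for confident target samples.

    batch_logits: list of per-sample logit lists from the classifier.
    threshold: hypothetical minimum softmax confidence to keep a sample.
    """
    selected = []
    for i, logits in enumerate(batch_logits):
        probs = softmax(logits)
        # Arg-max class; ties resolve to the lowest class index.
        label = max(range(len(probs)), key=probs.__getitem__)
        if probs[label] >= threshold:
            selected.append((i, label, probs[label]))
    return selected


if __name__ == "__main__":
    # Three unlabeled thermal samples, 3 classes (as in the FLIR setting).
    logits = [[4.0, 0.5, 0.1], [1.0, 1.1, 0.9], [0.0, 0.2, 5.0]]
    for idx, lbl, conf in select_pseudo_labels(logits):
        print(f"sample {idx}: pseudo-label {lbl} (p={conf:.2f})")
```

Here the nearly uniform second sample is discarded, while the two confident samples keep their arg-max class as a pseudo-label. Raising the threshold trades pseudo-label coverage for accuracy.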

Acknowledgement

This repo is based on:
