Effective prediction of acute respiratory failure is critical for proactive healthcare management. Electronic health records (EHRs) provide rich patient information that can aid in prediction, but the imbalanced class distribution poses a challenge. In this study, we propose a novel approach that uses optimal transport (OT) for imbalanced classification of EHRs, combined with multimodality fusion. Specifically, we leverage multiple modalities of patient data, including clinical notes, demographics, vital signs, and laboratory results, to provide a comprehensive view of a patient's health status. Results show that our approach significantly improves prediction performance by leveraging the OT framework, which optimally matches the probability distributions of the imbalanced classes.
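To make the multimodality fusion idea concrete, here is a minimal sketch assuming simple concatenation-based fusion followed by a logistic head. The function names (`fuse`, `predict`), the embedding sizes, and the example values are hypothetical and are not taken from this repo, whose actual model may fuse modalities differently (for example with PyTorch modules).

```python
import math

def fuse(embeddings, order):
    """Concatenate per-modality feature vectors in a fixed order."""
    fused = []
    for name in order:
        fused.extend(embeddings[name])
    return fused

def predict(fused, weights, bias=0.0):
    """Logistic head: sigmoid of a linear score over the fused vector."""
    score = sum(w * x for w, x in zip(weights, fused)) + bias
    return 1.0 / (1.0 + math.exp(-score))

# Hypothetical, made-up embeddings for one patient (not real data).
patient = {
    "notes": [0.2, -0.1],
    "demographics": [1.0],
    "vitals": [0.5, 0.3],
    "labs": [-0.4],
}
MODALITIES = ["notes", "demographics", "vitals", "labs"]

fused = fuse(patient, MODALITIES)
# Zero weights stand in for an untrained head (assumption for illustration).
risk = predict(fused, weights=[0.0] * len(fused))
```

Fixing the modality order keeps the fused vector layout stable across patients, which a downstream classifier would rely on.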
This work is inspired by "Learning to Re-weight Examples with Optimal Transport for Imbalanced Classification" (https://github.com/DandanGuo1993/reweight-imbalance-classification-with-OT).
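As a rough illustration of the OT component, the sketch below computes an entropy-regularised transport plan with Sinkhorn iterations between a non-uniform source histogram and a uniform target histogram. It is a generic pure-Python Sinkhorn, not the code from this repo or from the referenced one; the function name `sinkhorn`, the regularisation strength `eps`, and the toy cost matrix are assumptions chosen for illustration.

```python
import math

def sinkhorn(a, b, cost, eps=0.5, n_iters=200):
    """Entropic OT between histograms a (length n) and b (length m).

    cost[i][j] is the cost of moving mass from source i to target j.
    Returns the transport plan and its total transport cost.
    """
    n, m = len(a), len(b)
    # Gibbs kernel derived from the cost matrix.
    K = [[math.exp(-cost[i][j] / eps) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the marginals.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    plan = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    ot_cost = sum(plan[i][j] * cost[i][j] for i in range(n) for j in range(m))
    return plan, ot_cost

# Toy example: three weighted training samples, two balanced targets.
a = [0.7, 0.2, 0.1]            # hypothetical sample weights (sum to 1)
b = [0.5, 0.5]                 # uniform, balanced target distribution
cost = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]
plan, ot_cost = sinkhorn(a, b, cost)
```

In a re-weighting setting, `a` would hold the learnable sample weights and `ot_cost` would serve as the quantity to minimise; here both are fixed toy values.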
Requirements:
- Python 3.6
- PyTorch 1.7.1
- tqdm 4.19.9
- torchvision 0.8.2
- numpy 1.19.2
Results:
- Comparison of results between single modalities and multimodality
- Comparison of results between BCE loss only and BCE loss combined with OT loss using the fused model
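The second comparison contrasts training with BCE loss alone against BCE loss combined with an OT loss. A minimal sketch of such a combined objective follows, assuming the two terms are simply added with a trade-off coefficient. The names `bce_loss`, `combined_loss`, and `lam`, and the default value of `lam`, are hypothetical; the weighting actually used in OT_train.py may differ.

```python
import math

def bce_loss(probs, labels, eps=1e-12):
    """Mean binary cross-entropy over predicted probabilities."""
    total = 0.0
    for p, y in zip(probs, labels):
        # Clamp to avoid log(0) for saturated predictions.
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)

def combined_loss(probs, labels, ot_cost, lam=0.1):
    """BCE plus an OT term scaled by lam (hypothetical trade-off weight)."""
    return bce_loss(probs, labels) + lam * ot_cost

# Toy values: two predictions, one positive and one negative label.
probs = [0.5, 0.5]
labels = [1, 0]
bce_only = bce_loss(probs, labels)
bce_plus_ot = combined_loss(probs, labels, ot_cost=2.0, lam=0.1)
```

Setting `lam=0` recovers the BCE-only baseline, which makes the two training regimes easy to compare under the same code path.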
Note: This repo is configured to run on CSCC @MBZUAI.
- Change the data path.
- Update the configuration in run.sh.
- Run OT_train.py with the following command:
  sh run.sh