Mai-CS / ML703-OptimalTransport

Imbalanced Classification of Electronic Health Records using Optimal Transport

python pytorch

Effective prediction of acute respiratory failure is critical for proactive healthcare management. Electronic health records (EHRs) provide rich patient information that can aid in prediction, but the imbalanced class distribution poses a challenge. In this study, we propose a novel approach that uses optimal transport (OT) for imbalanced classification of EHRs, combined with multimodality fusion. Specifically, we leverage multiple modalities of patient data, including clinical notes, demographics, vital signs, and laboratory results, to provide a comprehensive view of a patient's health status. Results show that our approach significantly improves prediction performance by leveraging the OT framework, which optimally matches the probability distributions of the imbalanced classes.
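To make the idea of matching two distributions concrete, here is a minimal, dependency-free sketch of entropy-regularised OT (Sinkhorn iterations) between an imbalanced set of samples and a small balanced reference set. It is a generic illustration, not the code in OT_train.py: the 1-D toy features, the function names `sinkhorn_plan` and `transport_cost`, the squared-distance cost, and the values of `eps` and `n_iters` are all assumptions made for this example.

```python
import math

def sinkhorn_plan(a, b, cost, eps=0.5, n_iters=200):
    """Entropy-regularised OT plan between histograms a (rows) and b (columns)."""
    n, m = len(a), len(b)
    # Gibbs kernel: cheap pairs get values close to 1, expensive pairs close to 0.
    K = [[math.exp(-cost[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Alternately rescale rows and columns to match the two marginals.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

def transport_cost(plan, cost):
    """Total cost of moving mass according to the plan."""
    return sum(p * c for p_row, c_row in zip(plan, cost) for p, c in zip(p_row, c_row))

# Hypothetical toy data: three majority-class samples near 0, one minority sample near 0.9.
train_x = [0.0, 0.1, 0.2, 0.9]
# Hypothetical balanced reference set: one sample per class.
balanced_x = [0.1, 0.9]

cost = [[(x - y) ** 2 for y in balanced_x] for x in train_x]
a = [1.0 / len(train_x)] * len(train_x)        # uniform mass on training samples
b = [1.0 / len(balanced_x)] * len(balanced_x)  # uniform mass on the balanced set

plan = sinkhorn_plan(a, b, cost)
ot_loss = transport_cost(plan, cost)
```

In this toy setup the single minority sample sends most of its mass to the minority reference point, while part of the majority mass has to travel to it as well, which is what makes the transport cost non-zero. A smaller `eps` gives a sharper plan but generally needs more iterations to converge.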

Contributions

This work is inspired by "Learning to Re-weight Examples with Optimal Transport for Imbalanced Classification": https://github.com/DandanGuo1993/reweight-imbalance-classification-with-OT

Requirements

Python 3.6
PyTorch 1.7.1
tqdm 4.19.9
torchvision 0.8.2
numpy 1.19.2

Results

Comparison of results between single modalities and multimodality [screenshot: Screenshot-2023-05-05-at-2-28-57-AM]
Comparison of results between BCE loss only and BCE loss combined with OT loss using the fused model [screenshot: Screenshot-2023-05-05-at-2-24-31-AM]
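
The second comparison contrasts a BCE-only objective with BCE plus an OT term. The README does not say how the two terms are combined, so the sketch below assumes a simple weighted sum; the trade-off weight `lam`, the function names, and the example numbers are hypothetical. It is written in plain Python to stay dependency-free; in a PyTorch training loop the BCE part would more likely come from `torch.nn.functional.binary_cross_entropy`.

```python
import math

def bce(probs, labels, tiny=1e-12):
    """Mean binary cross-entropy for predicted positive-class probabilities."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, tiny), 1.0 - tiny)  # clamp to avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(labels)

def combined_loss(probs, labels, ot_term, lam=0.5):
    """Assumed form: BCE plus a weighted OT term (lam is a hypothetical weight)."""
    return bce(probs, labels) + lam * ot_term

# Hypothetical predictions and labels for three patients.
probs = [0.9, 0.2, 0.6]
labels = [1, 0, 1]

loss_bce = bce(probs, labels)                            # BCE-only objective
loss_total = combined_loss(probs, labels, ot_term=0.2)   # BCE + lam * OT
```

Setting `lam=0` recovers the BCE-only baseline, which makes the two configurations in the comparison easy to switch between under this assumed form.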

How to use

Note: This repo is configured to run on CSCC @MBZUAI.

  • Change the data path to point to your copy of the data
  • Change the configuration setup in run.sh
  • Run OT_train.py with the command: sh run.sh

Languages

Python 99.7%, Shell 0.3%