som-shahlab / ehrshot-benchmark

A benchmark for few-shot evaluation of foundation models for electronic health records (EHRs)

Home Page: https://ehrshot.stanford.edu


👂 💉 EHRSHOT

A benchmark/dataset for few-shot evaluation of foundation models for electronic health records (EHRs). You can read the paper here.

Please note that the dataset + model are still being reviewed, and a download link will be provided once they are approved for public release.


Whereas most prior EHR benchmarks are limited to the ICU setting, EHRSHOT contains the full longitudinal health records of 6,739 patients from Stanford Medicine and a diverse set of 15 classification tasks tailored towards few-shot evaluation of pre-trained models.

📖 Table of Contents

  1. Pre-trained Foundation Model
  2. Dataset + Tasks
  3. Comparison to Prior Work
  4. Installation
  5. Usage
  6. Citation

Access: The model is on HuggingFace here and requires signing a research usage agreement.

We publish the model weights of a 141 million parameter clinical foundation model pre-trained on the deidentified structured EHR data of 2.57M patients from Stanford Medicine.

We are one of the first to fully release such a model for coded EHR data; in contrast, most prior models released for clinical data (e.g. GatorTron, ClinicalBERT) only work with unstructured text and cannot process the rich, structured data within an EHR.

We use Clinical Language-Model-Based Representations (CLMBR) as our model. CLMBR is an autoregressive model designed to predict the next medical code in a patient's timeline given the previous codes. It employs causally masked local attention, ensuring a forward-only flow of information, which is vital for prediction tasks, in contrast to bidirectional BERT-based models. Our base model is a transformer with 141 million trainable parameters trained with a next-code-prediction objective, providing minute-level EHR resolution rather than the day-level aggregation of the original model formulation.
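As a rough illustration of the causal local attention described above (a sketch of the masking idea only, not the model's actual implementation; the window size is a placeholder):

import numpy as np

def causal_local_mask(seq_len: int, window: int) -> np.ndarray:
    # mask[i, j] is True iff position i may attend to position j.
    # Each position sees only itself and the `window` positions before it,
    # so information flows strictly forward in time.
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j >= i - window)

# With a next-code objective, the target at position t is the code at t + 1,
# so the model never attends to the code it is asked to predict.
print(causal_local_mask(seq_len=6, window=2).astype(int))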

Please Note: Dataset release is currently being reviewed and the download link will be updated once it is publicly available.

The EHRSHOT (version 1) dataset contains:

  • 6,739 patients
  • 41.6 million clinical events
  • 921,499 visits
  • 15 prediction tasks

Each patient consists of an ordered timeline of clinical events taken from the structured data of their EHR (e.g. diagnoses, procedures, prescriptions, etc.).
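For intuition only (the on-disk format is the CSV schema shipped with the dataset), a timeline can be pictured as a time-sorted list of (timestamp, code, value) tuples. The events below are hypothetical:

import datetime

# Hypothetical toy timeline -- real events use coded values from the released CSVs.
patient_timeline = [
    (datetime.datetime(2015, 3, 2, 9, 30), "SNOMED/38341003", None),  # diagnosis
    (datetime.datetime(2015, 3, 2, 9, 45), "LOINC/2823-3", 4.1),      # lab result
    (datetime.datetime(2015, 3, 2, 10, 0), "RxNorm/197361", None),    # prescription
]
patient_timeline.sort(key=lambda event: event[0])  # events are ordered in time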

Each task is a predictive classification task, and includes a canonical train/val/test split. The tasks are defined as follows:

| Task | Type | Prediction Time | Time Horizon |
|---|---|---|---|
| Long Length of Stay | Binary | 11:59pm on day of admission | Admission duration |
| 30-day Readmission | Binary | 11:59pm on day of discharge | 30 days post-discharge |
| ICU Transfer | Binary | 11:59pm on day of admission | Admission duration |
| Thrombocytopenia | 4-way Multiclass | Immediately before result is recorded | Next result |
| Hyperkalemia | 4-way Multiclass | Immediately before result is recorded | Next result |
| Hypoglycemia | 4-way Multiclass | Immediately before result is recorded | Next result |
| Hyponatremia | 4-way Multiclass | Immediately before result is recorded | Next result |
| Anemia | 4-way Multiclass | Immediately before result is recorded | Next result |
| Hypertension | Binary | 11:59pm on day of discharge | 1 year post-discharge |
| Hyperlipidemia | Binary | 11:59pm on day of discharge | 1 year post-discharge |
| Pancreatic Cancer | Binary | 11:59pm on day of discharge | 1 year post-discharge |
| Celiac | Binary | 11:59pm on day of discharge | 1 year post-discharge |
| Lupus | Binary | 11:59pm on day of discharge | 1 year post-discharge |
| Acute MI | Binary | 11:59pm on day of discharge | 1 year post-discharge |
| Chest X-Ray Findings | 14-way Multilabel | 24hrs before report is recorded | Next report |

Most prior benchmarks are (1) limited to the ICU setting and (2) not tailored towards few-shot evaluation of pre-trained models.

In contrast, EHRSHOT contains (1) the full breadth of longitudinal data that a health system would expect to have on the patients it treats and (2) a broad range of tasks designed to evaluate models' task adaptation and few-shot capabilities:

| Benchmark | Source | ICU/ED Visits | Non-ICU/ED Visits | # of Patients | # of Tasks | Few-Shot Evaluation | Dataset via DUA | Preprocessing Code | Model Weights |
|---|---|---|---|---|---|---|---|---|---|
| EHRSHOT | Stanford Medicine | ✓ | ✓ | 7k | 15 | ✓ | ✓ | ✓ | ✓ |
| MIMIC-Extract | MIMIC-III | ✓ | -- | 34k | 5 | -- | ✓ | ✓ | -- |
| Purushotham 2018 | MIMIC-III | ✓ | -- | 35k | 3 | -- | ✓ | ✓ | -- |
| Harutyunyan 2019 | MIMIC-III | ✓ | -- | 33k | 4 | -- | ✓ | ✓ | -- |
| Gupta 2022 | MIMIC-IV | ✓ | * | 257k | 4 | -- | ✓ | ✓ | -- |
| COP-E-CAT | MIMIC-IV | ✓ | * | 257k | 4 | -- | ✓ | ✓ | -- |
| Xie 2022 | MIMIC-IV | ✓ | * | 216k | 3 | -- | ✓ | ✓ | -- |
| eICU | eICU | ✓ | -- | 73k | 4 | -- | ✓ | ✓ | -- |
| EHR PT | MIMIC-III / eICU | ✓ | -- | 86k | 11 | ✓ | ✓ | ✓ | -- |
| FIDDLE | MIMIC-III / eICU | ✓ | -- | 157k | 3 | -- | ✓ | ✓ | -- |
| HiRID-ICU | HiRID | ✓ | -- | 33k | 6 | -- | ✓ | ✓ | -- |
| Solares 2020 | CPRD | ✓ | ✓ | 4M | 2 | -- | -- | -- | -- |

Please use the following steps to create an environment for running the EHRSHOT benchmark.

1): Install EHRSHOT

conda create -n EHRSHOT_ENV python=3.10 -y
conda activate EHRSHOT_ENV

git clone https://github.com/som-shahlab/ehrshot-benchmark.git
cd ehrshot-benchmark
pip install -r requirements.txt

2): Install FEMR

For our data preprocessing pipeline we use FEMR (Framework for Electronic Medical Records), a Python package for building deep learning models with EHR data.

You must also have CUDA/cuDNN installed; we recommend CUDA 11.8 and cuDNN 8.7.0.

Note that this currently only works on Linux machines.

pip install --upgrade "jax[cuda11_pip]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
conda install bazel=6 -y
pip install git+https://github.com/som-shahlab/femr.git@ehrshot_branch

# Alternatively, install the released FEMR package from PyPI:
# pip install femr==0.0.20
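After installing, a quick way to confirm that JAX can see your GPU (a sanity check, not part of the pipeline):

import jax
import jax.numpy as jnp

print(jax.devices())        # should list at least one GPU device, not just CPU
print(jnp.arange(5).sum())  # trivial computation to confirm the backend works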

Download Private Assets

You will need to separately download several assets that we cannot redistribute publicly on GitHub.

These include the dataset itself, the weights of the pre-trained foundation model we benchmark, and the Athena OHDSI Ontology.

A) Dataset & Foundation Model for EHRs

Please Note: Dataset + model release is currently being reviewed and the download link will be updated once it is publicly available.

Once this is downloaded, unzip it to get a folder called EHRSHOT_ASSETS/. Please move this folder to the root of this repo.

B) Athena OHDSI Ontology

Our pipeline requires the user to provide an ontology in order to map medical codes to their parents/children. We use the default Athena OHDSI Ontology for this.

Unfortunately, we cannot redistribute the Athena OHDSI Ontology ourselves, so you must separately download it by following these steps:

  1. Go to the Athena website at this link. You may need to create an account.
  2. Click the green "Download" button at the top right of the website
  3. Click the purple "Download Vocabularies" button below the green "Download" button
  4. Name the bundle "athena_download" and select the 5.x version
  5. Scroll to the bottom of the list, and click the blue "Download" button
  6. It will take some time for the download to be ready. Refresh the webpage here to check its status; once the download is ready, click "Download"
  7. After the download is complete, unzip the file and move all the files into the EHRSHOT_ASSETS/athena_download/ folder in your repo.

After downloading the Athena OHDSI Ontology, you will have to separately download the CPT subset of the ontology. You can follow the instructions in the readme.txt in your Athena download, or follow the steps below:

  1. Create a UMLS account here
  2. Get your UMLS API key here
  3. From the EHRSHOT_ASSETS/athena_download/ folder, run this command: bash cpt.sh <YOUR UMLS API KEY>

Your ontology will then be ready to go!
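As an optional sanity check, you can confirm that the core vocabulary files of a standard Athena export landed in the right folder (file names per the standard export; adjust if your bundle differs):

import os

athena_dir = "EHRSHOT_ASSETS/athena_download"
expected = ["CONCEPT.csv", "CONCEPT_RELATIONSHIP.csv", "CONCEPT_ANCESTOR.csv", "VOCABULARY.csv"]
missing = [name for name in expected if not os.path.exists(os.path.join(athena_dir, name))]
print("Missing files:", missing or "none")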

Folder Structure

Your final folder structure should look like this:

  • ehrshot-benchmark/
    • EHRSHOT_ASSETS/
      • data/
        • We provide this asset, which contains deidentified EHR data as CSVs.
      • benchmark/
        • We provide this asset, which contains labels and few-shot samples for all our tasks.
      • models/
        • We provide this asset, which contains our pre-trained foundation model for EHRs.
      • athena_download/
        • You will need to download and put the Athena OHDSI Ontology inside this folder. Please follow the instructions above to download it.
    • ehrshot/
      • We provide the scripts to run the benchmark here.

To execute the entire benchmark end-to-end, please run:

python3 run_all.py

You can also run each of the steps individually by directly calling their corresponding Python/Bash files in the ehrshot/ folder. Note that depending on your system, you may need to change the Bash scripts.

Here is a breakdown of what each step in the pipeline does:

1): Convert the EHRSHOT CSV files into a format that the FEMR library can process.

python3 1_create_femr_database.py \
    --path_to_input ../EHRSHOT_ASSETS/data \
    --path_to_target ../EHRSHOT_ASSETS/femr \
    --athena_download ../EHRSHOT_ASSETS/athena_download \
    --num_threads 10

Alternatively, you can also run

sbatch 1_create_femr_database_slurm.sh

Please make sure you adjust the Bash script to match your system. You may not be able to run it as a Slurm job.
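Once the extract is built, you can spot-check it from Python. The sketch below assumes the PatientDatabase interface of the pinned FEMR branch; exact attribute names may differ across FEMR versions:

import femr.datasets

# Path produced by step 1 (the extract lives under --path_to_target).
database = femr.datasets.PatientDatabase("../EHRSHOT_ASSETS/femr/extract")
print("Number of patients:", len(database))

# Iteration is assumed to yield patient IDs; check your FEMR version.
patient = database[next(iter(database))]
for event in patient.events[:5]:
    print(event.start, event.code, event.value)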

2): Apply the labeling functions defined in FEMR to our dataset to generate labels for our benchmark tasks.

Note that as part of our dataset release, we also include these labels in a CSV. Thus, you can skip the label generation part of the script (and go straight to featurization) by setting the --is_skip_label flag.

python3 2_generate_labels_and_features.py \
    --path_to_database ../EHRSHOT_ASSETS/femr/extract \
    --path_to_output_dir ../EHRSHOT_ASSETS/benchmark \
    --path_to_chexpert_csv ../EHRSHOT_ASSETS/benchmark/chexpert/chexpert_labeled_radiology_notes.csv \
    --labeling_function guo_los \
    --is_skip_label \
    --num_threads 10

In case you want to regenerate your labels, you can run the above command without the --is_skip_label flag.

The above command runs only the guo_los (Long Length of Stay) labeling function. You will need to run this script individually for each of the 15 tasks. Alternatively, run the Bash script below to iterate through every task automatically:

sbatch 2_generate_labels_and_features_slurm.sh
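Whether you use the released labels or regenerate them, you can quickly inspect a task's labels with pandas. The path and column names below are illustrative; check the actual files in your release:

import pandas as pd

# Hypothetical layout: one folder per task under benchmark/.
labels = pd.read_csv("../EHRSHOT_ASSETS/benchmark/guo_los/labeled_patients.csv")
print(labels.head())
print(labels["value"].value_counts())  # label balance, assuming a `value` column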

3): Generate a CLMBR representation for each patient for each label. Below is an example of how to run it for one task (guo_los).

Note that this job requires a GPU.

python3 3_generate_clmbr_representations.py \
    --path_to_clmbr_data ../EHRSHOT_ASSETS/models/clmbr_model \
    --path_to_database ../EHRSHOT_ASSETS/femr/extract \
    --path_to_labeled_featurized_data ../EHRSHOT_ASSETS/benchmark \
    --path_to_save ../EHRSHOT_ASSETS/clmbr_reps \
    --labeling_function guo_los

To run it for all tasks automatically, run the following Bash script:

sbatch 3_generate_clmbr_representations_slurm.sh
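The output of this step is one fixed-length CLMBR feature vector per (patient, prediction time) label. The snippet below only illustrates that shape; the file name, format, and hidden dimension are placeholders, not the script's actual output:

import numpy as np

# Hypothetical file: a (num_labels, hidden_dim) matrix of frozen CLMBR features.
features = np.load("../EHRSHOT_ASSETS/clmbr_reps/guo_los_features.npy")
print(features.shape)  # e.g. (num_labels, 768)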

4): Generate our k-shots for few-shot evaluation.

Note that the exact k-shots used in our paper are included in our data release; if you want to use those, do not run this script.

python3 4_generate_shot.py \
    --path_to_data ../EHRSHOT_ASSETS \
    --labeling_function guo_los \
    --num_replicates 1 \
    --path_to_save ../EHRSHOT_ASSETS/benchmark \
    --shot_strat few

To run it for all tasks automatically, run the following Bash script:

sbatch 4_generate_shot_slurm.sh
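Conceptually, generating k-shots for a binary task means reproducibly drawing k positive and k negative training examples per replicate. A minimal sketch of that idea (not the script's actual logic):

import numpy as np

def sample_k_shots(labels: np.ndarray, k: int, seed: int) -> np.ndarray:
    # Return indices of k positive and k negative examples, reproducibly.
    rng = np.random.default_rng(seed)
    pos = rng.choice(np.flatnonzero(labels == 1), size=k, replace=False)
    neg = rng.choice(np.flatnonzero(labels == 0), size=k, replace=False)
    return np.concatenate([pos, neg])

labels = np.array([0, 1, 0, 0, 1, 1, 0, 1])
print(sample_k_shots(labels, k=2, seed=0))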

5): Train our baseline models and generate performance metrics.

python3 5_eval.py \
    --path_to_data ../EHRSHOT_ASSETS \
    --labeling_function guo_los \
    --num_replicates 5 \
    --model_head logistic \
    --is_tune_hyperparams \
    --path_to_save ../EHRSHOT_ASSETS/output \
    --shot_strat few

To run it for all tasks automatically, run the following Bash script:

sbatch 5_eval_slurm.sh
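Each baseline fits a lightweight head on the frozen patient representations and reports discrimination metrics such as AUROC. A schematic of that loop with scikit-learn, using synthetic arrays in place of the real features and labels:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(32, 768)), rng.integers(0, 2, size=32)   # k-shot training set
X_test, y_test = rng.normal(size=(500, 768)), rng.integers(0, 2, size=500)   # held-out test set

# In the real pipeline, --is_tune_hyperparams tunes settings (e.g. L2 strength) on a validation split.
model = LogisticRegression(max_iter=1000, C=1.0)
model.fit(X_train, y_train)
print("AUROC:", roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]))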

6): Generate the plots we included in our paper.

python3 6_make_figures.py \
    --path_to_eval ../EHRSHOT_ASSETS/output \
    --path_to_save ../EHRSHOT_ASSETS/figures

or

sbatch 6_make_figures_slurm.sh

Citation

If you find this project helpful, please cite our paper:

@article{wornow2023ehrshot,
      title={EHRSHOT: An EHR Benchmark for Few-Shot Evaluation of Foundation Models}, 
      author={Michael Wornow and Rahul Thapa and Ethan Steinberg and Jason Fries and Nigam Shah},
      year={2023},
      eprint={2307.02028},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

License

The source code of this repo is released under the Apache License 2.0. Model licenses are listed on their corresponding webpages.
