Cells2Vec: Bridging The Gap Between Simulations And Experiments Using Causal Representation Learning -- Code
This repository contains the code for the experiments in "Cells2Vec: Bridging the gap between experiments and simulations using causal representation learning". In that work, we use causal representation learning to learn useful representations from high-dimensional simulations, validate these representations against experiments, and estimate simulation input parameters by regression.
Experiments were done with the following package versions for Python 3.10:
- Numpy (numpy) v1.21.5;
- Matplotlib (matplotlib) v3.5.3;
- Pandas (pandas) v2.0.1;
- CellModeller (cellmodeller-ingallslab) vx.x.x (to read raw simulations into torch);
- PyTorch (torch) v1.13.0 with CUDA 11.0;
- Scikit-learn (sklearn) v1.3.0;
- XGBoost (xgboost) v1.7.5.
This code should also execute correctly with updated versions of these packages. Use the requirements.txt file to install them, except for the cellmodeller-ingallslab package.
Below is one of many ways to set up a virtual environment. After cloning the repo, run the following in a terminal session:
mkdir env
cd env
virtualenv .
cd ..
source env/bin/activate
pip install -r requirements.txt
We used 1000 simulations, obtained by sampling the parameters Gamma, Reg_Param and Adhesion (Ref: CellModeller documentation). There were 100 parameter sets, with 10 simulations for each set.
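The layout of such a dataset (100 parameter sets, 10 replicate simulations each) can be sketched as follows. This is only an illustration: the sampling ranges, the uniform distribution, and the helper names `sample_parameter_sets` and `build_simulation_plan` are assumptions and are not taken from this repository.

```python
import random

# Hypothetical sampling ranges; the actual ranges are not stated in this README.
PARAM_RANGES = {
    "Gamma": (10.0, 100.0),
    "Reg_Param": (0.01, 0.1),
    "Adhesion": (0.0, 1.0),
}
NUM_PARAMETER_SETS = 100  # number of parameter sets (from the text above)
SIMS_PER_SET = 10         # simulations per parameter set (from the text above)


def sample_parameter_sets(num_sets, ranges, seed=0):
    """Draw num_sets parameter dictionaries, assuming uniform sampling."""
    rng = random.Random(seed)
    return [
        {name: rng.uniform(low, high) for name, (low, high) in ranges.items()}
        for _ in range(num_sets)
    ]


def build_simulation_plan(parameter_sets, sims_per_set):
    """Repeat each parameter set sims_per_set times, tagging runs with a class id."""
    return [
        {"class_id": class_id, "replicate": replicate, **params}
        for class_id, params in enumerate(parameter_sets)
        for replicate in range(sims_per_set)
    ]


parameter_sets = sample_parameter_sets(NUM_PARAMETER_SETS, PARAM_RANGES)
plan = build_simulation_plan(parameter_sets, SIMS_PER_SET)
print(len(plan))  # total number of planned simulations
```

Each parameter set acts as one class: the replicates share identical parameters and differ only in the stochastic outcome of the simulation.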
Our dataset is available for download here. After downloading, move it to the Data directory.
- losses.py file: implements the triplet loss with custom distance functions, as well as regularization for the same;
- networks.py file: implements the encoder and its building blocks (dilated convolutions, causal CNN) as well as LSTM and GRU encoders;
- data_utils.py file: implements custom PyTorch datasets, and code to sample triplets iteratively, unravel a padded tensor, and read a raw simulation into a tensor;
- configs/configs.yaml file: example of a YAML file containing the hyperparameters of a complete run of the code;
- eval.py file: code to calculate K-Means clustering metrics, and to fit an XGBoost regression model on the embeddings of a trained encoder to estimate parameters;
- trainer.py file: code for model training;
- visualize.py file: code for generating KMeans+PCA and TSNE plots;
- main.py file: wrapper file for an end-to-end run;
- sim2data.py file: reads raw simulations into torch tensors. For n iterations (directories) of k simulations, groups the first file from all directories together, then the second, and so on up to the k-th.
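The grouping performed by sim2data.py can be pictured with the minimal sketch below. It is not the repository's implementation: `group_by_position` and the file names are hypothetical, and the sketch assumes every directory lists its k files in the same order.

```python
def group_by_position(directories):
    """Group the i-th file of every directory together.

    `directories` is a list of n lists, each holding the k file names of one
    iteration in a fixed order. Returns k groups of n file names each.
    """
    lengths = {len(files) for files in directories}
    if len(lengths) > 1:
        raise ValueError("every directory must contain the same number of files")
    # zip(*...) transposes the n-by-k layout into k-by-n.
    return [list(group) for group in zip(*directories)]


# Hypothetical file names for n = 3 directories of k = 2 simulations each.
directories = [
    ["iter_0/sim_0", "iter_0/sim_1"],
    ["iter_1/sim_0", "iter_1/sim_1"],
    ["iter_2/sim_0", "iter_2/sim_1"],
]
groups = group_by_position(directories)
for group in groups:
    print(group)
```

With this layout, each resulting group collects the simulations that share the same position across iterations.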
- regression_analysis.ipynb notebook: contains code for Pearson correlation tests, analysis of residuals, and plots of the expected vs. observed parameters;
- sims_vs_exps.ipynb notebook: example code to compute similarity values to compare experiment embeddings and the corresponding simulation embeddings.
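As an illustration of comparing an experiment embedding with its corresponding simulation embedding, the sketch below uses cosine similarity. The choice of cosine similarity is an assumption (the notebook may use a different measure), and the function names and toy vectors are hypothetical.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two embedding vectors of equal length."""
    if len(u) != len(v):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


def pairwise_similarity(experiment_embeddings, simulation_embeddings):
    """Similarity of each experiment embedding with its matching simulation embedding."""
    return [
        cosine_similarity(exp, sim)
        for exp, sim in zip(experiment_embeddings, simulation_embeddings)
    ]


# Toy embeddings; real embeddings would come from the trained encoder.
experiments = [[1.0, 0.0], [1.0, 2.0]]
simulations = [[0.0, 1.0], [2.0, 4.0]]
scores = pairwise_similarity(experiments, simulations)
print(scores)
```

Values near 1 indicate that the two embeddings point in nearly the same direction, and values near 0 indicate that they are close to orthogonal.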
- checkpoints directory: plots and CSV results of the regression are generated here, along with the model checkpoint(s);
- runs directory: Tensorboard logs are saved here.
To train a model using default hyperparameters and to evaluate:
python3 main.py --config configs/configs.yaml --selected_config default
See the code documentation for more details. main.py can be called with the -h option for additional help.
Hyperparameters for the Encoder can be found here.
Hyperparameters for training:
- num_samples: total size of the training dataset. Random triplets are sampled to produce the training dataset;
- num_val: number of classes to exclude from the training process (to test the model's ability to generalize);
- val_indices: manually selected classes to exclude from training (random if not specified);
- split_idx: (see data_utils.py) splits the training set into a training set and a validation set (used for the early stopping condition).
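How these hyperparameters might interact can be sketched as follows: classes are held out first, and triplets are then drawn only from the remaining classes. This is a minimal illustration under assumptions, not the code in data_utils.py; `split_classes`, `sample_triplets`, and the rule "positive from the anchor's class, negative from another class" are hypothetical choices made for the example.

```python
import random


def split_classes(class_ids, num_val, val_indices=None, seed=0):
    """Return (train_classes, held_out_classes).

    If val_indices is not given, num_val classes are chosen at random.
    """
    classes = sorted(class_ids)
    if val_indices is None:
        val_indices = random.Random(seed).sample(classes, num_val)
    held_out = set(val_indices)
    train = [c for c in classes if c not in held_out]
    return train, sorted(held_out)


def sample_triplets(samples_by_class, train_classes, num_samples, seed=0):
    """Sample (anchor, positive, negative) triplets from the training classes only."""
    rng = random.Random(seed)
    triplets = []
    for _ in range(num_samples):
        # Two distinct classes: one for anchor/positive, one for the negative.
        pos_class, neg_class = rng.sample(train_classes, 2)
        # Two distinct simulations of the same class.
        anchor, positive = rng.sample(samples_by_class[pos_class], 2)
        negative = rng.choice(samples_by_class[neg_class])
        triplets.append((anchor, positive, negative))
    return triplets


# Toy dataset: 100 classes with 10 simulations each, identified by (class, replicate).
samples_by_class = {c: [(c, r) for r in range(10)] for c in range(100)}
train_classes, held_out = split_classes(samples_by_class.keys(), num_val=5)
triplets = sample_triplets(samples_by_class, train_classes, num_samples=1000)
print(len(triplets), held_out)
```

Because the held-out classes never appear in any triplet, embeddings of their simulations give an indication of how well the encoder generalizes to unseen parameter sets.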
Please consider citing our work using the following BibTeX entry:
@inproceedings{
cells2vec,
title={Cells2Vec: Bridging the gap between experiments and simulations using causal representation learning},
author={Dhruva, Rajwade and Ahmadi, Atiyeh and Ingalls, Brian},
booktitle={Causal Representation Learning Workshop at NeurIPS 2023},
year={2023},
url={https://openreview.net/forum?id=O9jfSs82XU}
}