A PyTorch Lightning-based implementation of "Adapting Text Embeddings for Causal Inference"
First, install the dependencies:
# clone project
git clone https://github.com/agoel00/causalBERT
# install project
cd causalBERT
pip install -r requirements.txt
Training CausalBERT
python run.py fit --accelerator gpu --batch_size 8
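The method trains the text encoder jointly on three signals: a propensity loss for predicting the treatment from the text, an outcome loss in which each example supervises only the outcome head for its observed treatment, and a masked-language-model loss. Below is a minimal, framework-free sketch of such a combined objective. It assumes binary treatment and outcome, and it is not the code in `run.py`. The batch keys (`g_logit`, `q0_logit`, `q1_logit`, `t`, `y`) and the loss weights are hypothetical names chosen for illustration.

```python
import math


def bce_with_logit(logit: float, label: int) -> float:
    """Numerically stable binary cross-entropy for one logit and a 0/1 label."""
    # Equivalent to -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))]
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def causal_bert_loss(batch, mlm_loss, g_weight=1.0, q_weight=1.0, mlm_weight=1.0):
    """Combine propensity, outcome, and masked-LM losses for one batch.

    `batch` is a list of dicts with hypothetical keys:
      g_logit  - predicted logit of P(T=1 | text)
      q0_logit - predicted logit of P(Y=1 | T=0, text)
      q1_logit - predicted logit of P(Y=1 | T=1, text)
      t, y     - observed binary treatment and outcome
    The weights are illustrative defaults, not values taken from this repo.
    """
    n = len(batch)
    # Propensity loss: predict the observed treatment from the text.
    g_loss = sum(bce_with_logit(ex["g_logit"], ex["t"]) for ex in batch) / n
    # Outcome loss: only the head matching the observed treatment is supervised.
    q_loss = sum(
        bce_with_logit(ex["q1_logit"] if ex["t"] == 1 else ex["q0_logit"], ex["y"])
        for ex in batch
    ) / n
    return g_weight * g_loss + q_weight * q_loss + mlm_weight * mlm_loss
```

Because only the head for the observed treatment receives a gradient, the other head is learned purely from examples in the opposite treatment group.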
Inference using the trained CausalBERT checkpoint
python run.py predict --accelerator gpu --batch_size 8 --ckpt_path last
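Treatment effects come from per-example predictions of the two outcome heads. The sketch below shows the simple plug-in ("Q-only") estimators for the average treatment effect on the treated (ATT) and the average treatment effect (ATE). It assumes `predict` produces, for each example, Q0 and Q1 outcome probabilities along with the observed treatment. That output format is an assumption, since the README does not describe it, and `q_only_att` and `q_only_ate` are hypothetical helper names.

```python
def q_only_att(q0, q1, t):
    """Plug-in ATT: mean of Q1 - Q0 over the treated examples only.

    q0[i], q1[i]: predicted E[Y | T=0, text_i] and E[Y | T=1, text_i]
    t[i]: observed binary treatment for example i
    """
    if not (len(q0) == len(q1) == len(t)):
        raise ValueError("q0, q1 and t must have the same length")
    diffs = [b - a for a, b, t_i in zip(q0, q1, t) if t_i == 1]
    if not diffs:
        raise ValueError("ATT is undefined when no example is treated")
    return sum(diffs) / len(diffs)


def q_only_ate(q0, q1):
    """Plug-in ATE: mean of Q1 - Q0 over all examples."""
    if len(q0) != len(q1) or not q0:
        raise ValueError("q0 and q1 must be non-empty and the same length")
    return sum(b - a for a, b in zip(q0, q1)) / len(q0)


q0 = [0.2, 0.5, 0.1]
q1 = [0.6, 0.5, 0.4]
t = [1, 0, 1]
print(round(q_only_att(q0, q1, t), 4))  # 0.35
print(round(q_only_ate(q0, q1), 4))     # 0.2333
```

The ATT averages only over treated examples, so the untreated second example (with no predicted effect) is excluded from the ATT but pulls the ATE down.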
Much of the training logic is adapted from https://github.com/rpryzant/causal-bert-pytorch.