sxjscience / PreDiff

Official implementation of PreDiff


PreDiff

By Zhihan Gao, Xingjian Shi, Boran Han, Hao Wang, Xiaoyong Jin, Danielle Maddix Robinson, Yi Zhu, Yuyang Bernie Wang, Mu Li, Dit-Yan Yeung.

This repo is the official implementation of "PreDiff: Precipitation Nowcasting with Latent Diffusion Models" that will appear in NeurIPS 2023.

Introduction

Earth system forecasting has traditionally relied on complex physical models that are computationally expensive and require significant domain expertise. Over the past decade, the unprecedented growth of spatiotemporal Earth observation data has enabled data-driven forecasting models based on deep learning. These models have shown promise on diverse Earth system forecasting tasks, but they either struggle to handle uncertainty, averaging possible futures into blurry forecasts, or neglect domain-specific prior knowledge, producing physically implausible predictions. To address these limitations, we propose a two-stage pipeline for probabilistic spatiotemporal forecasting:

  1. We develop PreDiff, a conditional latent diffusion model capable of probabilistic forecasts.
  2. We incorporate an explicit knowledge alignment mechanism to align forecasts with domain-specific physical constraints. This is achieved by estimating the deviation from imposed constraints at each denoising step and adjusting the transition distribution accordingly.
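The alignment idea in item 2, shifting each denoising transition to reduce a constraint deviation, can be illustrated with a minimal NumPy sketch in the spirit of classifier guidance. Everything below is an assumption for illustration only: the constraint is taken to be the mean intensity of the latent, the deviation is its squared distance from an anticipated value, the gradient is evaluated at the unaligned mean, and the names `aligned_transition_mean`, `aligned_step`, and the scale `lam` are hypothetical. The repository instead trains a knowledge alignment network to estimate the deviation, and its exact update may differ.

```python
import numpy as np

def aligned_transition_mean(mu, var, target, lam):
    """Shift the mean of the Gaussian transition N(mu, var) toward a constraint.

    Hypothetical constraint: F(z) = mean(z); deviation D(z) = (F(z) - target)**2.
    The gradient of D is evaluated at the unaligned mean `mu` (a simplification).
    """
    grad = 2.0 * (mu.mean() - target) / mu.size * np.ones_like(mu)
    return mu - lam * var * grad

def aligned_step(mu, var, target, lam, rng):
    # Sample z_t from the adjusted transition N(aligned mean, var).
    mean = aligned_transition_mean(mu, var, target, lam)
    return mean + np.sqrt(var) * rng.standard_normal(mu.shape)

rng = np.random.default_rng(0)
mu = np.zeros((4, 4))  # toy unaligned mean predicted by the denoiser
print(aligned_transition_mean(mu, var=1.0, target=1.0, lam=8.0).mean())  # 1.0
z_t = aligned_step(mu, var=1.0, target=1.0, lam=8.0, rng=rng)
```

With 16 elements, unit variance, and `lam=8`, the shifted mean lands exactly on the target; smaller values of `lam` close only part of the gap, which is the usual trade-off between satisfying the constraint and staying close to the learned distribution.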

We conduct empirical studies on two datasets: N-body MNIST, a synthetic dataset with chaotic behavior, and SEVIR, a real-world precipitation nowcasting dataset. Specifically, we impose the law of conservation of energy in N-body MNIST and anticipated precipitation intensity in SEVIR. Experiments demonstrate the effectiveness of PreDiff in handling uncertainty, incorporating domain-specific prior knowledge, and generating forecasts that exhibit high operational utility.


Overview of PreDiff inference with knowledge alignment. An observation sequence $y$ is encoded into a latent context $z_{\text{cond}}$ by the frame-wise encoder $\mathcal{E}$. The latent diffusion model $p_\theta(z_t|z_{t+1}, z_{\text{cond}})$, which is parameterized by an Earthformer-UNet, then generates the latent future $z_0$ by autoregressively denoising Gaussian noise $z_T$ conditioned on $z_{\text{cond}}$. It takes the concatenation of the latent context $z_{\text{cond}}$ and the previous-step noisy latent future $z_{t+1}$ as input, and outputs $z_t$. The transition distribution of each step from $z_{t+1}$ to $z_t$ can be further refined via knowledge alignment, according to auxiliary prior knowledge. $z_0$ is decoded back to pixel space by the frame-wise decoder $\mathcal{D}$ to produce the final prediction $\hat{x}$.
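To make the data flow in this figure concrete, here is a shape-level NumPy sketch of the inference loop. The encoder, denoiser, and decoder are toy stand-ins (the hypothetical `encode`, `denoise_step`, and `decode`), not the repository's VAE or Earthformer-UNet. The 8x spatial downsampling, the 4 latent channels, and concatenation along the time axis are assumptions chosen only to show how the tensors fit together.

```python
import numpy as np

T_IN, T_OUT, H, W = 7, 6, 128, 128  # SEVIR-LR context/target shape
F, C = 8, 4  # assumed latent downsampling factor and channel count

def encode(y):
    # Toy frame-wise encoder E: average-pool each frame by F, tile to C channels.
    t = y.shape[0]
    pooled = y.reshape(t, H // F, F, W // F, F).mean(axis=(2, 4))
    return np.repeat(pooled[..., None], C, axis=-1)  # (t, H/F, W/F, C)

def decode(z):
    # Toy frame-wise decoder D: average channels, upsample by repetition.
    frames = z.mean(axis=-1)
    return frames.repeat(F, axis=1).repeat(F, axis=2)  # (t, H, W)

def denoise_step(z_next, z_cond, t):
    # Stand-in for p_theta(z_t | z_{t+1}, z_cond): the network sees the
    # concatenation [z_cond, z_{t+1}] and returns only the future frames.
    # The toy update ignores the step index t; a real denoiser uses it.
    x = np.concatenate([z_cond, z_next], axis=0)  # (T_IN + T_OUT, h, w, C)
    context_mean = x[:T_IN].mean(axis=0, keepdims=True)
    return 0.9 * x[T_IN:] + 0.1 * context_mean  # (T_OUT, h, w, C)

def sample(y, num_steps=10, seed=0):
    rng = np.random.default_rng(seed)
    z_cond = encode(y)
    z = rng.standard_normal((T_OUT, H // F, W // F, C))  # z_T ~ N(0, I)
    for t in reversed(range(num_steps)):
        z = denoise_step(z, z_cond, t)  # knowledge alignment would adjust this step
    return decode(z)  # x_hat in pixel space

x_hat = sample(np.random.default_rng(1).random((T_IN, H, W)))
print(x_hat.shape)  # (6, 128, 128)
```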

Installation

Create Conda environment

conda create --name prediff python=3.10.12
conda activate prediff

Install PyTorch with CUDA 11.7 support, and PyTorch Lightning

python -m pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install lightning==2.0.9

Install PreDiff in dev mode

cd ROOT_DIR/PreDiff
python -m pip install -e . --no-build-isolation
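After installation, it can help to confirm that the pinned versions are the ones actually installed. The following stdlib-only sketch (the helper `check_pins` is hypothetical, not part of the repository) reads installed distribution versions with `importlib.metadata` and reports any mismatch against the pins used in the commands above.

```python
from importlib.metadata import PackageNotFoundError, version

# Version pins taken from the installation commands above.
PINS = {
    "torch": "2.0.1+cu117",
    "torchvision": "0.15.2+cu117",
    "lightning": "2.0.9",
}

def check_pins(pins):
    """Return {package: (installed_or_None, expected)} for every mismatch."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package not installed at all
        if installed != expected:
            mismatches[name] = (installed, expected)
    return mismatches

if __name__ == "__main__":
    problems = check_pins(PINS)
    for name, (got, want) in problems.items():
        print(f"{name}: installed {got}, expected {want}")
    print("all pins satisfied" if not problems else f"{len(problems)} mismatch(es)")
```

This only checks package versions; whether a GPU is actually visible can be checked separately with `torch.cuda.is_available()`.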

Datasets

The Storm EVent ImageRy (SEVIR) dataset is a spatiotemporally aligned dataset containing over 10,000 weather events. Its NEXRAD Vertically Integrated Liquid (VIL) mosaics are adopted for benchmarking precipitation nowcasting, i.e., predicting the future VIL up to 60 minutes ahead given 65 minutes of context VIL. Since SEVIR frames are 5 minutes apart, this corresponds to $13\times 384\times 384\rightarrow 12\times 384\times 384$ (frames × height × width). We adopt a downsampled version of SEVIR, denoted SEVIR-LR, with a temporal downscaling factor of 2 and a spatial downscaling factor of 3 along each spatial dimension. On SEVIR-LR, PreDiff generates $6\times 128\times 128$ forecasts given a $7\times 128\times 128$ context sequence.
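The shape arithmetic behind SEVIR-LR can be checked with a small NumPy sketch. It assumes a temporal stride of 2 and 3x3 average pooling; the repository's `downsample_sevir.py` may use a different resampling method, and the function name `downsample_sevir` here is hypothetical.

```python
import numpy as np

def downsample_sevir(seq, t_factor=2, s_factor=3):
    """Toy SEVIR -> SEVIR-LR downsampling of a (T, H, W) VIL sequence.

    Assumption: keep every t_factor-th frame, then average-pool each frame
    over s_factor x s_factor blocks.
    """
    seq = seq[::t_factor]  # temporal downscaling
    t, h, w = seq.shape
    blocks = seq.reshape(t, h // s_factor, s_factor, w // s_factor, s_factor)
    return blocks.mean(axis=(2, 4))  # spatial downscaling

# 13 context + 12 target frames, 384 x 384 pixels (random stand-in data).
full = np.random.default_rng(0).random((25, 384, 384), dtype=np.float32)
lr = downsample_sevir(full)
context, target = lr[:7], lr[7:]
print(lr.shape, context.shape, target.shape)  # (13, 128, 128) (7, 128, 128) (6, 128, 128)
```

Taking every other frame of the 25-frame window keeps original frames 0, 2, ..., 24, so the first 7 come from the 13 context frames and the last 6 from the 12 target frames, matching the $7\rightarrow 6$ split above.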

Figure: Example of a SEVIR VIL sequence.

To download the SEVIR-LR dataset directly from AWS S3, run:

cd ROOT_DIR/PreDiff
python ./scripts/datasets/sevir/download_sevirlr.py

The SEVIRLightningDataModule can also download it automatically the first time you call prepare_data().

Alternatively, if you already have the original SEVIR dataset, you can obtain SEVIR-LR by downsampling it. In this case, run:

cd ROOT_DIR/PreDiff
ln -s path_to_SEVIR ./datasets/sevir  # link to your SEVIR dataset if it is not in `ROOT_DIR/PreDiff/datasets`
python ./scripts/datasets/sevir/downsample_sevir.py

Training Script and Pretrained Models

Test pretrained PreDiff

Run the following command to download all pretrained weights in advance. Use the --model flag to download a specific pretrained model component. The available choices are vae, earthformerunet, alignment, and all.

cd ROOT_DIR/PreDiff
python ./scripts/download_pretrained.py --model all

Run the following command to load the pretrained models and run inference on the SEVIR-LR dataset:

cd ROOT_DIR/PreDiff
MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/prediff/sevirlr/train_sevirlr_prediff.py --gpus 2 --pretrained --save tmp_sevirlr_prediff

The results will be saved to the directory ROOT_DIR/PreDiff/experiments/tmp_sevirlr_prediff.

Note that because inference is extremely time-consuming, it is run only on the example sequences used for visualization. To evaluate on the full validation/test sets, set vis.eval_example_only: false in the config.
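The dotted name `vis.eval_example_only` refers to a nested key: `eval_example_only` inside the `vis` section of the YAML config. As a hypothetical illustration (the helper `set_by_dotted_key` is not part of the repository), a config loaded as nested dictionaries can be updated like this:

```python
def set_by_dotted_key(cfg, dotted_key, value):
    """Set cfg["a"]["b"] = value for dotted_key "a.b", creating missing sections."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value
    return cfg

# Toy stand-in for the loaded config; the real file has many more fields.
cfg = {"vis": {"eval_example_only": True}}
set_by_dotted_key(cfg, "vis.eval_example_only", False)
print(cfg)  # {'vis': {'eval_example_only': False}}
```

The same nesting convention applies to other dotted keys mentioned in this README, such as `vae.pretrained_ckpt_path`.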

Train from scratch

Our two-stage pipeline sequentially trains PreDiff and the knowledge alignment network. The training of PreDiff is further decomposed into two sequential phases: training the VAE and the latent Earthformer-UNet. To train all components from scratch, follow these sequential steps:

  1. Train the VAE.
  2. Train the latent Earthformer-UNet with the VAE trained in step 1.
  3. Train the knowledge alignment network with the VAE trained in step 1.

In practice, the training of the knowledge alignment network is independent of the training of the latent Earthformer-UNet. Therefore, steps 2 and 3 can be performed in parallel. To achieve this, specify the path to the PyTorch state_dict of the VAE trained in step 1 by setting vae.pretrained_ckpt_path in the corresponding config files.
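The dependency structure of the three training steps can be made explicit with Python's standard `graphlib` module (Python 3.9+). This is only an illustration of the ordering described above, not a script from the repository; the step names are hypothetical labels.

```python
from graphlib import TopologicalSorter

# Each training step maps to the set of steps it depends on.
deps = {
    "vae": set(),
    "earthformer_unet": {"vae"},
    "alignment": {"vae"},
}

def training_stages(graph):
    """Group steps into stages; steps within one stage can run in parallel."""
    ts = TopologicalSorter(graph)
    ts.prepare()
    stages = []
    while ts.is_active():
        ready = sorted(ts.get_ready())  # all steps whose dependencies are done
        stages.append(ready)
        ts.done(*ready)
    return stages

print(training_stages(deps))  # [['vae'], ['alignment', 'earthformer_unet']]
```

The second stage contains both the latent Earthformer-UNet and the alignment network, which is why steps 2 and 3 can be launched together once the VAE checkpoint exists.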

Detailed instructions on how to train the models or run inference with our pretrained models can be found in the corresponding script folders.

| Model Component | Script Folder | Config |
| --- | --- | --- |
| VAE | scripts | config |
| Latent Earthformer-UNet | scripts | config |
| Knowledge Alignment Network | scripts | config |

Citing PreDiff

@inproceedings{gao2023prediff,
  title={PreDiff: Precipitation Nowcasting with Latent Diffusion Models},
  author={Gao, Zhihan and Shi, Xingjian and Han, Boran and Wang, Hao and Jin, Xiaoyong and Robinson, Danielle and Zhu, Yi and Wang, Yuyang and Li, Mu and Yeung, Dit-Yan},
  booktitle={NeurIPS},
  year={2023}
}

Credits

Third-party libraries:

License

This project is licensed under the Apache-2.0 License.
