🐛[BUG]: ensemble perturbation_strategy: "spherical_grf"
ChenggongWang opened this issue
Chenggong Wang commented
Version
source - main
On which installation method(s) does this occur?
No response
Describe the issue
When running an ensemble forecast with perturbation_strategy: "spherical_grf", the following error occurs:
Traceback (most recent call last):
  File "/workdir/earth2mip/examples/test_ensemble_inference.py", line 197, in <module>
    inference_ensemble.main(config_str)
  File "/usr/local/lib/python3.10/dist-packages/earth2mip/inference_ensemble.py", line 205, in main
    run_inference(model, config, perturb, group)
  File "/usr/local/lib/python3.10/dist-packages/earth2mip/inference_ensemble.py", line 382, in run_inference
    run_ensembles(
  File "/usr/local/lib/python3.10/dist-packages/earth2mip/inference_ensemble.py", line 104, in run_ensembles
    x_start = perturb(x, rank, batch_id, model.device)
  File "/usr/local/lib/python3.10/dist-packages/earth2mip/inference_ensemble.py", line 256, in perturb
    return x + noise * scale[:, None, None]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
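The failing line adds a noise field to the initial state `x`, which lives on `cuda:0`, while at least one of the other operands (`noise` or `scale`) is still on the CPU, presumably because the spherical GRF noise is sampled there. A minimal sketch of one possible fix (not the upstream patch) is to move both operands onto `x.device` before the addition. `add_scaled_noise` is a hypothetical helper, and the tensor shapes are illustrative assumptions rather than earth2mip's actual layout.

```python
import torch


def add_scaled_noise(x: torch.Tensor, noise: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring `x + noise * scale[:, None, None]`.

    Moves `noise` and `scale` onto x's device and dtype first, so a noise
    field sampled on the CPU can be added to a CUDA state tensor.
    """
    noise = noise.to(device=x.device, dtype=x.dtype)
    scale = scale.to(device=x.device, dtype=x.dtype)
    # scale[:, None, None] has shape (channels, 1, 1) and broadcasts over lat/lon.
    return x + noise * scale[:, None, None]


device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumed layout: (batch, channel, lat, lon); scale holds one amplitude per channel.
x = torch.zeros(1, 3, 4, 8, device=device)
noise = torch.randn(1, 3, 4, 8)        # sampled on the CPU
scale = torch.tensor([0.1, 0.2, 0.3])  # per-channel amplitude, also on the CPU
x_start = add_scaled_noise(x, noise, scale)
print(x_start.device)
```

On a machine without a GPU the sketch runs entirely on the CPU, where the mismatch cannot occur and the `.to(...)` calls are no-ops; on a GPU machine the result stays on `cuda:0`, matching `x`.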
Environment details
The error can be reproduced as follows:
1. Install earth2mip from GitHub inside a PyTorch Docker image (nvcr.io/nvidia/pytorch:23.12-py3).
2. In the ensemble inference example, change perturbation_strategy to "spherical_grf": https://github.com/NVIDIA/earth2mip/blob/main/examples/01_ensemble_inference.py#L173
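Step 2 amounts to something like the sketch below. It assumes, as the `inference_ensemble.main(config_str)` call in the traceback suggests, that the example builds a config dict and passes it to `main` as a JSON string. Only the key relevant to this issue is shown; the remaining keys from the example are omitted, so this fragment is not a complete configuration.

```python
import json

# Hypothetical excerpt: the real example's config dict contains more keys
# (weather model, weather event, output path, ...), which are omitted here.
config = {
    "perturbation_strategy": "spherical_grf",  # the setting that triggers the error
}

# Serialize the dict, assuming main() takes the config as a JSON string.
config_str = json.dumps(config)
print(config_str)  # {"perturbation_strategy": "spherical_grf"}
# inference_ensemble.main(config_str)  # not run here; requires earth2mip and a model
```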