MeteoSwiss / ldcast

Latent diffusion for generative precipitation nowcasting

loading trained model

tomasvanoyen opened this issue · comments

Hi @jleinonen,

thank you for this nice work.

I am trying to retrain the model using the script

python train_genforecast.py --model_dir="../models/genforecast_train"

to see if I can reproduce the weights and obtain somewhat similar results. However, I am failing to load the model checkpoints (.ckpt files) into the Forecast class. Please note that loading the pretrained weights coming from the Zenodo data repository does work.

Below I will provide the error message, but I also observed that the size on disk of the checkpoints created by the train_genforecast.py script is almost double that of genforecast-radaronly-256x256-20step.pt.

I guess I am missing an obvious step here?

Error message:

from ldcast.forecast import Forecast
fn_aut = 'models/autoenc/autoenc-32-0.01.pt'
fn_gen = 'models/genforecast_train/epoch=0-val_loss_ema=0.6150.ckpt'
fc = Forecast(ldm_weights_fn=fn_gen, autoenc_weights_fn=fn_aut)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/workspace/thirdparty/meteoswiss/ldcast/forecast.py", line 49, in __init__
    self.ldm = self._init_model()
  File "/workspace/thirdparty/meteoswiss/ldcast/forecast.py", line 99, in _init_model
    ldm.load_state_dict(torch.load(self.ldm_weights_fn))
  File "/workspace/virtualenv/venv_ldcast/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LatentDiffusion:
Missing key(s) in state_dict: "betas", "alphas_cumprod", "alphas_cumprod_prev", "sqrt_alphas_cumprod", "sqrt_one_minus_alphas_cumprod", "model.time_embed.0.weight", "model.time_embed.0.bias", "model.time_embed.2.weight", "model.time_embed.2.bias", "model.input_blocks.0.0.weight", "model.input_blocks.0.0.bias", "model.input_blocks.1.0.in_layers.2.weight", "model.input_blocks.1.0.in_layers.2.bias", "model.input_blocks.1.0.emb_layers.1.weight", "model.input_blocks.1.0.emb_layers.1.bias", "model.input_blocks.1.0.out_layers.3.weight", "model.input_blocks.1.0.out_layers.3.bias", "model.input_blocks.1.1.pre_proj.weight", "model.input_blocks.1.1.pre_proj.bias", "model.input_blocks.1.1.filter.w1", "model.input_blocks.1.1.filter.b1", "model.input_blocks.1.1.filter.w2", "model.input_blocks.1.1.filter.b2", "model.input_blocks.1.1.mlp.fc1.weight", "model.input_blocks.1.1.mlp.fc1.bias", "model.input_blocks.1.1.mlp.fc2.weight", "model.input_blocks.1.1.mlp.fc2.bias", "model.input_blocks.2.0.in_layers.2.weight", "model.input_blocks.2.0.in_layers.2.bias", "model.input_blocks.2.0.emb_layers.1.weight", "model.input_blocks.2.0.emb_layers.1.bias", "model.input_blocks.2.0.out_layers.3.weight", "model.input_blocks.2.0.out_layers.3.bias", "model.input_blocks.2.1.pre_proj.weight", "model.input_blocks.2.1.pre_proj.bias", "model.input_blocks.2.1.filter.w1", "model.input_blocks.2.1.filter.b1", 
"model.input_blocks.2.1.filter.w2", "model.input_blocks.2.1.filter.b2", "model.input_blocks.2.1.mlp.fc1.weight", "model.input_blocks.2.1.mlp.fc1.bias", "model.input_blocks.2.1.mlp.fc2.weight", "model.input_blocks.2.1.mlp.fc2.bias", "model.input_blocks.3.0.op.weight", "model.input_blocks.3.0.op.bias", "model.input_blocks.4.0.in_layers.2.weight", "model.input_blocks.4.0.in_layers.2.bias", "model.input_blocks.4.0.emb_layers.1.weight", "model.input_blocks.4.0.emb_layers.1.bias", "model.input_blocks.4.0.out_layers.3.weight", "model.input_blocks.4.0.out_layers.3.bias", "model.input_blocks.4.0.skip_connection.weight", "model.input_blocks.4.0.skip_connection.bias", "model.input_blocks.4.1.pre_proj.weight", "model.input_blocks.4.1.pre_proj.bias", "model.input_blocks.4.1.filter.w1", "model.input_blocks.4.1.filter.b1", "model.input_blocks.4.1.filter.w2", "model.input_blocks.4.1.filter.b2", "model.input_blocks.4.1.mlp.fc1.weight", "model.input_blocks.4.1.mlp.fc1.bias", "model.input_blocks.4.1.mlp.fc2.weight", "model.input_blocks.4.1.mlp.fc2.bias", "model.input_blocks.5.0.in_layers.2.weight", "model.input_blocks.5.0.in_layers.2.bias", "model.input_blocks.5.0.emb_layers.1.weight", "model.input_blocks.5.0.emb_layers.1.bias", "model.input_blocks.5.0.out_layers.3.weight", "model.input_blocks.5.0.out_layers.3.bias", "model.input_blocks.5.1.pre_proj.weight", "model.input_blocks.5.1.pre_proj.bias", "model.input_blocks.5.1.filter.w1", "model.input_blocks.5.1.filter.b1", "model.input_blocks.5.1.filter.w2", "model.input_blocks.5.1.filter.b2", "model.input_blocks.5.1.mlp.fc1.weight", "model.input_blocks.5.1.mlp.fc1.bias", "model.input_blocks.5.1.mlp.fc2.weight", "model.input_blocks.5.1.mlp.fc2.bias", "model.input_blocks.6.0.op.weight", "model.input_blocks.6.0.op.bias", "model.input_blocks.7.0.in_layers.2.weight", "model.input_blocks.7.0.in_layers.2.bias", "model.input_blocks.7.0.emb_layers.1.weight", "model.input_blocks.7.0.emb_layers.1.bias", "model.input_blocks.7.0.out_layers.3.weight", 
"model.input_blocks.7.0.out_layers.3.bias", "model.input_blocks.7.0.skip_connection.weight", "model.input_blocks.7.0.skip_connection.bias", "model.input_blocks.8.0.in_layers.2.weight", "model.input_blocks.8.0.in_layers.2.bias", "model.input_blocks.8.0.emb_layers.1.weight", "model.input_blocks.8.0.emb_layers.1.bias", "model.input_blocks.8.0.out_layers.3.weight", "model.input_blocks.8.0.out_layers.3.bias", "model.middle_block.0.in_layers.2.weight", "model.middle_block.0.in_layers.2.bias", "model.middle_block.0.emb_layers.1.weight", "model.middle_block.0.emb_layers.1.bias", "model.middle_block.0.out_layers.3.weight", "model.middle_block.0.out_layers.3.bias", "model.middle_block.1.pre_proj.weight", "model.middle_block.1.pre_proj.bias", "model.middle_block.1.filter.w1", "model.middle_block.1.filter.b1", "model.middle_block.1.filter.w2", "model.middle_block.1.filter.b2", "model.middle_block.1.mlp.fc1.weight", "model.middle_block.1.mlp.fc1.bias", "model.middle_block.1.mlp.fc2.weight", "model.middle_block.1.mlp.fc2.bias", "model.middle_block.2.in_layers.2.weight", "model.middle_block.2.in_layers.2.bias", "model.middle_block.2.emb_layers.1.weight", "model.middle_block.2.emb_layers.1.bias", "model.middle_block.2.out_layers.3.weight", "model.middle_block.2.out_layers.3.bias", "model.output_blocks.0.0.in_layers.2.weight", "model.output_blocks.0.0.in_layers.2.bias", "model.output_blocks.0.0.emb_layers.1.weight", "model.output_blocks.0.0.emb_layers.1.bias", "model.output_blocks.0.0.out_layers.3.weight", "model.output_blocks.0.0.out_layers.3.bias", "model.output_blocks.0.0.skip_connection.weight", "model.output_blocks.0.0.skip_connection.bias", "model.output_blocks.1.0.in_layers.2.weight", "model.output_blocks.1.0.in_layers.2.bias", "model.output_blocks.1.0.emb_layers.1.weight", "model.output_blocks.1.0.emb_layers.1.bias", "model.output_blocks.1.0.out_layers.3.weight", "model.output_blocks.1.0.out_layers.3.bias", "model.output_blocks.1.0.skip_connection.weight", 
"model.output_blocks.1.0.skip_connection.bias", "model.output_blocks.2.0.in_layers.2.weight", "model.output_blocks.2.0.in_layers.2.bias", "model.output_blocks.2.0.emb_layers.1.weight", "model.output_blocks.2.0.emb_layers.1.bias", "model.output_blocks.2.0.out_layers.3.weight", "model.output_blocks.2.0.out_layers.3.bias", "model.output_blocks.2.0.skip_connection.weight", "model.output_blocks.2.0.skip_connection.bias", "model.output_blocks.2.1.conv.weight", "model.output_blocks.2.1.conv.bias", "model.output_blocks.3.0.in_layers.2.weight", "model.output_blocks.3.0.in_layers.2.bias", "model.output_blocks.3.0.emb_layers.1.weight", "model.output_blocks.3.0.emb_layers.1.bias", "model.output_blocks.3.0.out_layers.3.weight", "model.output_blocks.3.0.out_layers.3.bias", "model.output_blocks.3.0.skip_connection.weight", "model.output_blocks.3.0.skip_connection.bias", "model.output_blocks.3.1.pre_proj.weight", "model.output_blocks.3.1.pre_proj.bias", "model.output_blocks.3.1.filter.w1", "model.output_blocks.3.1.filter.b1", "model.output_blocks.3.1.filter.w2", "model.output_blocks.3.1.filter.b2", "model.output_blocks.3.1.mlp.fc1.weight", "model.output_blocks.3.1.mlp.fc1.bias", "model.output_blocks.3.1.mlp.fc2.weight", "model.output_blocks.3.1.mlp.fc2.bias", "model.output_blocks.4.0.in_layers.2.weight", "model.output_blocks.4.0.in_layers.2.bias", "model.output_blocks.4.0.emb_layers.1.weight", "model.output_blocks.4.0.emb_layers.1.bias", "model.output_blocks.4.0.out_layers.3.weight", "model.output_blocks.4.0.out_layers.3.bias", "model.output_blocks.4.0.skip_connection.weight", "model.output_blocks.4.0.skip_connection.bias", "model.output_blocks.4.1.pre_proj.weight", "model.output_blocks.4.1.pre_proj.bias", "model.output_blocks.4.1.filter.w1", "model.output_blocks.4.1.filter.b1", "model.output_blocks.4.1.filter.w2", "model.output_blocks.4.1.filter.b2", "model.output_blocks.4.1.mlp.fc1.weight", "model.output_blocks.4.1.mlp.fc1.bias", "model.output_blocks.4.1.mlp.fc2.weight", 
"model.output_blocks.4.1.mlp.fc2.bias", "model.output_blocks.5.0.in_layers.2.weight", "model.output_blocks.5.0.in_layers.2.bias", "model.output_blocks.5.0.emb_layers.1.weight", "model.output_blocks.5.0.emb_layers.1.bias", "model.output_blocks.5.0.out_layers.3.weight", "model.output_blocks.5.0.out_layers.3.bias", "model.output_blocks.5.0.skip_connection.weight", "model.output_blocks.5.0.skip_connection.bias", "model.output_blocks.5.1.pre_proj.weight", "model.output_blocks.5.1.pre_proj.bias", "model.output_blocks.5.1.filter.w1", "model.output_blocks.5.1.filter.b1", "model.output_blocks.5.1.filter.w2", "model.output_blocks.5.1.filter.b2", "model.output_blocks.5.1.mlp.fc1.weight", "model.output_blocks.5.1.mlp.fc1.bias", "model.output_blocks.5.1.mlp.fc2.weight", "model.output_blocks.5.1.mlp.fc2.bias", "model.output_blocks.5.2.conv.weight", "model.output_blocks.5.2.conv.bias", "model.output_blocks.6.0.in_layers.2.weight", "model.output_blocks.6.0.in_layers.2.bias", "model.output_blocks.6.0.emb_layers.1.weight", "model.output_blocks.6.0.emb_layers.1.bias", "model.output_blocks.6.0.out_layers.3.weight", "model.output_blocks.6.0.out_layers.3.bias", "model.output_blocks.6.0.skip_connection.weight", "model.output_blocks.6.0.skip_connection.bias", "model.output_blocks.6.1.pre_proj.weight", "model.output_blocks.6.1.pre_proj.bias", "model.output_blocks.6.1.filter.w1", "model.output_blocks.6.1.filter.b1", "model.output_blocks.6.1.filter.w2", "model.output_blocks.6.1.filter.b2", "model.output_blocks.6.1.mlp.fc1.weight", "model.output_blocks.6.1.mlp.fc1.bias", "model.output_blocks.6.1.mlp.fc2.weight", "model.output_blocks.6.1.mlp.fc2.bias", "model.output_blocks.7.0.in_layers.2.weight", "model.output_blocks.7.0.in_layers.2.bias", "model.output_blocks.7.0.emb_layers.1.weight", "model.output_blocks.7.0.emb_layers.1.bias", "model.output_blocks.7.0.out_layers.3.weight", "model.output_blocks.7.0.out_layers.3.bias", "model.output_blocks.7.0.skip_connection.weight", 
"model.output_blocks.7.0.skip_connection.bias", "model.output_blocks.7.1.pre_proj.weight", "model.output_blocks.7.1.pre_proj.bias", "model.output_blocks.7.1.filter.w1", "model.output_blocks.7.1.filter.b1", "model.output_blocks.7.1.filter.w2", "model.output_blocks.7.1.filter.b2", "model.output_blocks.7.1.mlp.fc1.weight", "model.output_blocks.7.1.mlp.fc1.bias", "model.output_blocks.7.1.mlp.fc2.weight", "model.output_blocks.7.1.mlp.fc2.bias", "model.output_blocks.8.0.in_layers.2.weight", "model.output_blocks.8.0.in_layers.2.bias", "model.output_blocks.8.0.emb_layers.1.weight", "model.output_blocks.8.0.emb_layers.1.bias", "model.output_blocks.8.0.out_layers.3.weight", "model.output_blocks.8.0.out_layers.3.bias", "model.output_blocks.8.0.skip_connection.weight", "model.output_blocks.8.0.skip_connection.bias", "model.output_blocks.8.1.pre_proj.weight", "model.output_blocks.8.1.pre_proj.bias", "model.output_blocks.8.1.filter.w1", "model.output_blocks.8.1.filter.b1", "model.output_blocks.8.1.filter.w2", "model.output_blocks.8.1.filter.b2", "model.output_blocks.8.1.mlp.fc1.weight", "model.output_blocks.8.1.mlp.fc1.bias", "model.output_blocks.8.1.mlp.fc2.weight", "model.output_blocks.8.1.mlp.fc2.bias", "model.out.2.weight", "model.out.2.bias", "autoencoder.log_var", "autoencoder.encoder.0.proj.weight", "autoencoder.encoder.0.proj.bias", "autoencoder.encoder.0.conv1.weight", "autoencoder.encoder.0.conv1.bias", "autoencoder.encoder.0.conv2.weight", "autoencoder.encoder.0.conv2.bias", "autoencoder.encoder.0.norm1.weight", "autoencoder.encoder.0.norm1.bias", "autoencoder.encoder.0.norm2.weight", "autoencoder.encoder.0.norm2.bias", "autoencoder.encoder.1.weight", "autoencoder.encoder.1.bias", "autoencoder.encoder.2.conv1.weight", "autoencoder.encoder.2.conv1.bias", "autoencoder.encoder.2.conv2.weight", "autoencoder.encoder.2.conv2.bias", "autoencoder.encoder.2.norm1.weight", "autoencoder.encoder.2.norm1.bias", "autoencoder.encoder.2.norm2.weight", 
"autoencoder.encoder.2.norm2.bias", "autoencoder.encoder.3.weight", "autoencoder.encoder.3.bias", "autoencoder.decoder.0.weight", "autoencoder.decoder.0.bias", "autoencoder.decoder.1.conv1.weight", "autoencoder.decoder.1.conv1.bias", "autoencoder.decoder.1.conv2.weight", "autoencoder.decoder.1.conv2.bias", "autoencoder.decoder.1.norm1.weight", "autoencoder.decoder.1.norm1.bias", "autoencoder.decoder.1.norm2.weight", "autoencoder.decoder.1.norm2.bias", "autoencoder.decoder.2.weight", "autoencoder.decoder.2.bias", "autoencoder.decoder.3.proj.weight", "autoencoder.decoder.3.proj.bias", "autoencoder.decoder.3.conv1.weight", "autoencoder.decoder.3.conv1.bias", "autoencoder.decoder.3.conv2.weight", "autoencoder.decoder.3.conv2.bias", "autoencoder.decoder.3.norm1.weight", "autoencoder.decoder.3.norm1.bias", "autoencoder.decoder.3.norm2.weight", "autoencoder.decoder.3.norm2.bias", "autoencoder.to_moments.weight", "autoencoder.to_moments.bias", "autoencoder.to_decoder.weight", "autoencoder.to_decoder.bias", "context_encoder.autoencoder.0.log_var", "context_encoder.autoencoder.0.encoder.0.proj.weight", "context_encoder.autoencoder.0.encoder.0.proj.bias", "context_encoder.autoencoder.0.encoder.0.conv1.weight", "context_encoder.autoencoder.0.encoder.0.conv1.bias", "context_encoder.autoencoder.0.encoder.0.conv2.weight", "context_encoder.autoencoder.0.encoder.0.conv2.bias", "context_encoder.autoencoder.0.encoder.0.norm1.weight", "context_encoder.autoencoder.0.encoder.0.norm1.bias", "context_encoder.autoencoder.0.encoder.0.norm2.weight", "context_encoder.autoencoder.0.encoder.0.norm2.bias", "context_encoder.autoencoder.0.encoder.1.weight", "context_encoder.autoencoder.0.encoder.1.bias", "context_encoder.autoencoder.0.encoder.2.conv1.weight", "context_encoder.autoencoder.0.encoder.2.conv1.bias", "context_encoder.autoencoder.0.encoder.2.conv2.weight", "context_encoder.autoencoder.0.encoder.2.conv2.bias", "context_encoder.autoencoder.0.encoder.2.norm1.weight", 
"context_encoder.autoencoder.0.encoder.2.norm1.bias", "context_encoder.autoencoder.0.encoder.2.norm2.weight", "context_encoder.autoencoder.0.encoder.2.norm2.bias", "context_encoder.autoencoder.0.encoder.3.weight", "context_encoder.autoencoder.0.encoder.3.bias", "context_encoder.autoencoder.0.decoder.0.weight", "context_encoder.autoencoder.0.decoder.0.bias", "context_encoder.autoencoder.0.decoder.1.conv1.weight", "context_encoder.autoencoder.0.decoder.1.conv1.bias", "context_encoder.autoencoder.0.decoder.1.conv2.weight", "context_encoder.autoencoder.0.decoder.1.conv2.bias", "context_encoder.autoencoder.0.decoder.1.norm1.weight", "context_encoder.autoencoder.0.decoder.1.norm1.bias", "context_encoder.autoencoder.0.decoder.1.norm2.weight", "context_encoder.autoencoder.0.decoder.1.norm2.bias", "context_encoder.autoencoder.0.decoder.2.weight", "context_encoder.autoencoder.0.decoder.2.bias", "context_encoder.autoencoder.0.decoder.3.proj.weight", "context_encoder.autoencoder.0.decoder.3.proj.bias", "context_encoder.autoencoder.0.decoder.3.conv1.weight", "context_encoder.autoencoder.0.decoder.3.conv1.bias", "context_encoder.autoencoder.0.decoder.3.conv2.weight", "context_encoder.autoencoder.0.decoder.3.conv2.bias", "context_encoder.autoencoder.0.decoder.3.norm1.weight", "context_encoder.autoencoder.0.decoder.3.norm1.bias", "context_encoder.autoencoder.0.decoder.3.norm2.weight", "context_encoder.autoencoder.0.decoder.3.norm2.bias", "context_encoder.autoencoder.0.to_moments.weight", "context_encoder.autoencoder.0.to_moments.bias", "context_encoder.autoencoder.0.to_decoder.weight", "context_encoder.autoencoder.0.to_decoder.bias", "context_encoder.proj.0.weight", "context_encoder.proj.0.bias", "context_encoder.analysis.0.0.norm1.weight", "context_encoder.analysis.0.0.norm1.bias", "context_encoder.analysis.0.0.filter.w1", "context_encoder.analysis.0.0.filter.b1", "context_encoder.analysis.0.0.filter.w2", "context_encoder.analysis.0.0.filter.b2", 
"context_encoder.analysis.0.0.norm2.weight", "context_encoder.analysis.0.0.norm2.bias", "context_encoder.analysis.0.0.mlp.fc1.weight", "context_encoder.analysis.0.0.mlp.fc1.bias", "context_encoder.analysis.0.0.mlp.fc2.weight", "context_encoder.analysis.0.0.mlp.fc2.bias", "context_encoder.analysis.0.1.norm1.weight", "context_encoder.analysis.0.1.norm1.bias", "context_encoder.analysis.0.1.filter.w1", "context_encoder.analysis.0.1.filter.b1", "context_encoder.analysis.0.1.filter.w2", "context_encoder.analysis.0.1.filter.b2", "context_encoder.analysis.0.1.norm2.weight", "context_encoder.analysis.0.1.norm2.bias", "context_encoder.analysis.0.1.mlp.fc1.weight", "context_encoder.analysis.0.1.mlp.fc1.bias", "context_encoder.analysis.0.1.mlp.fc2.weight", "context_encoder.analysis.0.1.mlp.fc2.bias", "context_encoder.analysis.0.2.norm1.weight", "context_encoder.analysis.0.2.norm1.bias", "context_encoder.analysis.0.2.filter.w1", "context_encoder.analysis.0.2.filter.b1", "context_encoder.analysis.0.2.filter.w2", "context_encoder.analysis.0.2.filter.b2", "context_encoder.analysis.0.2.norm2.weight", "context_encoder.analysis.0.2.norm2.bias", "context_encoder.analysis.0.2.mlp.fc1.weight", "context_encoder.analysis.0.2.mlp.fc1.bias", "context_encoder.analysis.0.2.mlp.fc2.weight", "context_encoder.analysis.0.2.mlp.fc2.bias", "context_encoder.analysis.0.3.norm1.weight", "context_encoder.analysis.0.3.norm1.bias", "context_encoder.analysis.0.3.filter.w1", "context_encoder.analysis.0.3.filter.b1", "context_encoder.analysis.0.3.filter.w2", "context_encoder.analysis.0.3.filter.b2", "context_encoder.analysis.0.3.norm2.weight", "context_encoder.analysis.0.3.norm2.bias", "context_encoder.analysis.0.3.mlp.fc1.weight", "context_encoder.analysis.0.3.mlp.fc1.bias", "context_encoder.analysis.0.3.mlp.fc2.weight", "context_encoder.analysis.0.3.mlp.fc2.bias", "context_encoder.temporal_transformer.0.attn1.KV.weight", "context_encoder.temporal_transformer.0.attn1.KV.bias", 
"context_encoder.temporal_transformer.0.attn1.Q.weight", "context_encoder.temporal_transformer.0.attn1.Q.bias", "context_encoder.temporal_transformer.0.attn1.proj.weight", "context_encoder.temporal_transformer.0.attn1.proj.bias", "context_encoder.temporal_transformer.0.attn2.KV.weight", "context_encoder.temporal_transformer.0.attn2.KV.bias", "context_encoder.temporal_transformer.0.attn2.Q.weight", "context_encoder.temporal_transformer.0.attn2.Q.bias", "context_encoder.temporal_transformer.0.attn2.proj.weight", "context_encoder.temporal_transformer.0.attn2.proj.bias", "context_encoder.temporal_transformer.0.norm1.weight", "context_encoder.temporal_transformer.0.norm1.bias", "context_encoder.temporal_transformer.0.norm2.weight", "context_encoder.temporal_transformer.0.norm2.bias", "context_encoder.temporal_transformer.0.norm3.weight", "context_encoder.temporal_transformer.0.norm3.bias", "context_encoder.temporal_transformer.0.mlp.0.weight", "context_encoder.temporal_transformer.0.mlp.0.bias", "context_encoder.temporal_transformer.0.mlp.2.weight", "context_encoder.temporal_transformer.0.mlp.2.bias", "context_encoder.forecast.0.norm1.weight", "context_encoder.forecast.0.norm1.bias", "context_encoder.forecast.0.filter.w1", "context_encoder.forecast.0.filter.b1", "context_encoder.forecast.0.filter.w2", "context_encoder.forecast.0.filter.b2", "context_encoder.forecast.0.norm2.weight", "context_encoder.forecast.0.norm2.bias", "context_encoder.forecast.0.mlp.fc1.weight", "context_encoder.forecast.0.mlp.fc1.bias", "context_encoder.forecast.0.mlp.fc2.weight", "context_encoder.forecast.0.mlp.fc2.bias", "context_encoder.forecast.1.norm1.weight", "context_encoder.forecast.1.norm1.bias", "context_encoder.forecast.1.filter.w1", "context_encoder.forecast.1.filter.b1", "context_encoder.forecast.1.filter.w2", "context_encoder.forecast.1.filter.b2", "context_encoder.forecast.1.norm2.weight", "context_encoder.forecast.1.norm2.bias", "context_encoder.forecast.1.mlp.fc1.weight", 
"context_encoder.forecast.1.mlp.fc1.bias", "context_encoder.forecast.1.mlp.fc2.weight", "context_encoder.forecast.1.mlp.fc2.bias", "context_encoder.forecast.2.norm1.weight", "context_encoder.forecast.2.norm1.bias", "context_encoder.forecast.2.filter.w1", "context_encoder.forecast.2.filter.b1", "context_encoder.forecast.2.filter.w2", "context_encoder.forecast.2.filter.b2", "context_encoder.forecast.2.norm2.weight", "context_encoder.forecast.2.norm2.bias", "context_encoder.forecast.2.mlp.fc1.weight", "context_encoder.forecast.2.mlp.fc1.bias", "context_encoder.forecast.2.mlp.fc2.weight", "context_encoder.forecast.2.mlp.fc2.bias", "context_encoder.forecast.3.norm1.weight", "context_encoder.forecast.3.norm1.bias", "context_encoder.forecast.3.filter.w1", "context_encoder.forecast.3.filter.b1", "context_encoder.forecast.3.filter.w2", "context_encoder.forecast.3.filter.b2", "context_encoder.forecast.3.norm2.weight", "context_encoder.forecast.3.norm2.bias", "context_encoder.forecast.3.mlp.fc1.weight", "context_encoder.forecast.3.mlp.fc1.bias", "context_encoder.forecast.3.mlp.fc2.weight", "context_encoder.forecast.3.mlp.fc2.bias", "context_encoder.resnet.0.proj.weight", "context_encoder.resnet.0.proj.bias", "context_encoder.resnet.0.conv1.weight", "context_encoder.resnet.0.conv1.bias", "context_encoder.resnet.0.conv2.weight", "context_encoder.resnet.0.conv2.bias", "context_encoder.resnet.1.proj.weight", "context_encoder.resnet.1.proj.bias", "context_encoder.resnet.1.conv1.weight", "context_encoder.resnet.1.conv1.bias", "context_encoder.resnet.1.conv2.weight", "context_encoder.resnet.1.conv2.bias", "model_ema.decay", "model_ema.num_updates", "model_ema.time_embed0weight", "model_ema.time_embed0bias", "model_ema.time_embed2weight", "model_ema.time_embed2bias", "model_ema.input_blocks00weight", "model_ema.input_blocks00bias", "model_ema.input_blocks10in_layers2weight", "model_ema.input_blocks10in_layers2bias", "model_ema.input_blocks10emb_layers1weight", 
"model_ema.input_blocks10emb_layers1bias", "model_ema.input_blocks10out_layers3weight", "model_ema.input_blocks10out_layers3bias", "model_ema.input_blocks11pre_projweight", "model_ema.input_blocks11pre_projbias", "model_ema.input_blocks11filterw1", "model_ema.input_blocks11filterb1", "model_ema.input_blocks11filterw2", "model_ema.input_blocks11filterb2", "model_ema.input_blocks11mlpfc1weight", "model_ema.input_blocks11mlpfc1bias", "model_ema.input_blocks11mlpfc2weight", "model_ema.input_blocks11mlpfc2bias", "model_ema.input_blocks20in_layers2weight", "model_ema.input_blocks20in_layers2bias", "model_ema.input_blocks20emb_layers1weight", "model_ema.input_blocks20emb_layers1bias", "model_ema.input_blocks20out_layers3weight", "model_ema.input_blocks20out_layers3bias", "model_ema.input_blocks21pre_projweight", "model_ema.input_blocks21pre_projbias", "model_ema.input_blocks21filterw1", "model_ema.input_blocks21filterb1", "model_ema.input_blocks21filterw2", "model_ema.input_blocks21filterb2", "model_ema.input_blocks21mlpfc1weight", "model_ema.input_blocks21mlpfc1bias", "model_ema.input_blocks21mlpfc2weight", "model_ema.input_blocks21mlpfc2bias", "model_ema.input_blocks30opweight", "model_ema.input_blocks30opbias", "model_ema.input_blocks40in_layers2weight", "model_ema.input_blocks40in_layers2bias", "model_ema.input_blocks40emb_layers1weight", "model_ema.input_blocks40emb_layers1bias", "model_ema.input_blocks40out_layers3weight", "model_ema.input_blocks40out_layers3bias", "model_ema.input_blocks40skip_connectionweight", "model_ema.input_blocks40skip_connectionbias", "model_ema.input_blocks41pre_projweight", "model_ema.input_blocks41pre_projbias", "model_ema.input_blocks41filterw1", "model_ema.input_blocks41filterb1", "model_ema.input_blocks41filterw2", "model_ema.input_blocks41filterb2", "model_ema.input_blocks41mlpfc1weight", "model_ema.input_blocks41mlpfc1bias", "model_ema.input_blocks41mlpfc2weight", "model_ema.input_blocks41mlpfc2bias", 
"model_ema.input_blocks50in_layers2weight", "model_ema.input_blocks50in_layers2bias", "model_ema.input_blocks50emb_layers1weight", "model_ema.input_blocks50emb_layers1bias", "model_ema.input_blocks50out_layers3weight", "model_ema.input_blocks50out_layers3bias", "model_ema.input_blocks51pre_projweight", "model_ema.input_blocks51pre_projbias", "model_ema.input_blocks51filterw1", "model_ema.input_blocks51filterb1", "model_ema.input_blocks51filterw2", "model_ema.input_blocks51filterb2", "model_ema.input_blocks51mlpfc1weight", "model_ema.input_blocks51mlpfc1bias", "model_ema.input_blocks51mlpfc2weight", "model_ema.input_blocks51mlpfc2bias", "model_ema.input_blocks60opweight", "model_ema.input_blocks60opbias", "model_ema.input_blocks70in_layers2weight", "model_ema.input_blocks70in_layers2bias", "model_ema.input_blocks70emb_layers1weight", "model_ema.input_blocks70emb_layers1bias", "model_ema.input_blocks70out_layers3weight", "model_ema.input_blocks70out_layers3bias", "model_ema.input_blocks70skip_connectionweight", "model_ema.input_blocks70skip_connectionbias", "model_ema.input_blocks80in_layers2weight", "model_ema.input_blocks80in_layers2bias", "model_ema.input_blocks80emb_layers1weight", "model_ema.input_blocks80emb_layers1bias", "model_ema.input_blocks80out_layers3weight", "model_ema.input_blocks80out_layers3bias", "model_ema.middle_block0in_layers2weight", "model_ema.middle_block0in_layers2bias", "model_ema.middle_block0emb_layers1weight", "model_ema.middle_block0emb_layers1bias", "model_ema.middle_block0out_layers3weight", "model_ema.middle_block0out_layers3bias", "model_ema.middle_block1pre_projweight", "model_ema.middle_block1pre_projbias", "model_ema.middle_block1filterw1", "model_ema.middle_block1filterb1", "model_ema.middle_block1filterw2", "model_ema.middle_block1filterb2", "model_ema.middle_block1mlpfc1weight", "model_ema.middle_block1mlpfc1bias", "model_ema.middle_block1mlpfc2weight", "model_ema.middle_block1mlpfc2bias", 
"model_ema.middle_block2in_layers2weight", "model_ema.middle_block2in_layers2bias", "model_ema.middle_block2emb_layers1weight", "model_ema.middle_block2emb_layers1bias", "model_ema.middle_block2out_layers3weight", "model_ema.middle_block2out_layers3bias", "model_ema.output_blocks00in_layers2weight", "model_ema.output_blocks00in_layers2bias", "model_ema.output_blocks00emb_layers1weight", "model_ema.output_blocks00emb_layers1bias", "model_ema.output_blocks00out_layers3weight", "model_ema.output_blocks00out_layers3bias", "model_ema.output_blocks00skip_connectionweight", "model_ema.output_blocks00skip_connectionbias", "model_ema.output_blocks10in_layers2weight", "model_ema.output_blocks10in_layers2bias", "model_ema.output_blocks10emb_layers1weight", "model_ema.output_blocks10emb_layers1bias", "model_ema.output_blocks10out_layers3weight", "model_ema.output_blocks10out_layers3bias", "model_ema.output_blocks10skip_connectionweight", "model_ema.output_blocks10skip_connectionbias", "model_ema.output_blocks20in_layers2weight", "model_ema.output_blocks20in_layers2bias", "model_ema.output_blocks20emb_layers1weight", "model_ema.output_blocks20emb_layers1bias", "model_ema.output_blocks20out_layers3weight", "model_ema.output_blocks20out_layers3bias", "model_ema.output_blocks20skip_connectionweight", "model_ema.output_blocks20skip_connectionbias", "model_ema.output_blocks21convweight", "model_ema.output_blocks21convbias", "model_ema.output_blocks30in_layers2weight", "model_ema.output_blocks30in_layers2bias", "model_ema.output_blocks30emb_layers1weight", "model_ema.output_blocks30emb_layers1bias", "model_ema.output_blocks30out_layers3weight", "model_ema.output_blocks30out_layers3bias", "model_ema.output_blocks30skip_connectionweight", "model_ema.output_blocks30skip_connectionbias", "model_ema.output_blocks31pre_projweight", "model_ema.output_blocks31pre_projbias", "model_ema.output_blocks31filterw1", "model_ema.output_blocks31filterb1", "model_ema.output_blocks31filterw2", 
"model_ema.output_blocks31filterb2", "model_ema.output_blocks31mlpfc1weight", "model_ema.output_blocks31mlpfc1bias", "model_ema.output_blocks31mlpfc2weight", "model_ema.output_blocks31mlpfc2bias", "model_ema.output_blocks40in_layers2weight", "model_ema.output_blocks40in_layers2bias", "model_ema.output_blocks40emb_layers1weight", "model_ema.output_blocks40emb_layers1bias", "model_ema.output_blocks40out_layers3weight", "model_ema.output_blocks40out_layers3bias", "model_ema.output_blocks40skip_connectionweight", "model_ema.output_blocks40skip_connectionbias", "model_ema.output_blocks41pre_projweight", "model_ema.output_blocks41pre_projbias", "model_ema.output_blocks41filterw1", "model_ema.output_blocks41filterb1", "model_ema.output_blocks41filterw2", "model_ema.output_blocks41filterb2", "model_ema.output_blocks41mlpfc1weight", "model_ema.output_blocks41mlpfc1bias", "model_ema.output_blocks41mlpfc2weight", "model_ema.output_blocks41mlpfc2bias", "model_ema.output_blocks50in_layers2weight", "model_ema.output_blocks50in_layers2bias", "model_ema.output_blocks50emb_layers1weight", "model_ema.output_blocks50emb_layers1bias", "model_ema.output_blocks50out_layers3weight", "model_ema.output_blocks50out_layers3bias", "model_ema.output_blocks50skip_connectionweight", "model_ema.output_blocks50skip_connectionbias", "model_ema.output_blocks51pre_projweight", "model_ema.output_blocks51pre_projbias", "model_ema.output_blocks51filterw1", "model_ema.output_blocks51filterb1", "model_ema.output_blocks51filterw2", "model_ema.output_blocks51filterb2", "model_ema.output_blocks51mlpfc1weight", "model_ema.output_blocks51mlpfc1bias", "model_ema.output_blocks51mlpfc2weight", "model_ema.output_blocks51mlpfc2bias", "model_ema.output_blocks52convweight", "model_ema.output_blocks52convbias", "model_ema.output_blocks60in_layers2weight", "model_ema.output_blocks60in_layers2bias", "model_ema.output_blocks60emb_layers1weight", "model_ema.output_blocks60emb_layers1bias", 
"model_ema.output_blocks60out_layers3weight", "model_ema.output_blocks60out_layers3bias", "model_ema.output_blocks60skip_connectionweight", "model_ema.output_blocks60skip_connectionbias", "model_ema.output_blocks61pre_projweight", "model_ema.output_blocks61pre_projbias", "model_ema.output_blocks61filterw1", "model_ema.output_blocks61filterb1", "model_ema.output_blocks61filterw2", "model_ema.output_blocks61filterb2", "model_ema.output_blocks61mlpfc1weight", "model_ema.output_blocks61mlpfc1bias", "model_ema.output_blocks61mlpfc2weight", "model_ema.output_blocks61mlpfc2bias", "model_ema.output_blocks70in_layers2weight", "model_ema.output_blocks70in_layers2bias", "model_ema.output_blocks70emb_layers1weight", "model_ema.output_blocks70emb_layers1bias", "model_ema.output_blocks70out_layers3weight", "model_ema.output_blocks70out_layers3bias", "model_ema.output_blocks70skip_connectionweight", "model_ema.output_blocks70skip_connectionbias", "model_ema.output_blocks71pre_projweight", "model_ema.output_blocks71pre_projbias", "model_ema.output_blocks71filterw1", "model_ema.output_blocks71filterb1", "model_ema.output_blocks71filterw2", "model_ema.output_blocks71filterb2", "model_ema.output_blocks71mlpfc1weight", "model_ema.output_blocks71mlpfc1bias", "model_ema.output_blocks71mlpfc2weight", "model_ema.output_blocks71mlpfc2bias", "model_ema.output_blocks80in_layers2weight", "model_ema.output_blocks80in_layers2bias", "model_ema.output_blocks80emb_layers1weight", "model_ema.output_blocks80emb_layers1bias", "model_ema.output_blocks80out_layers3weight", "model_ema.output_blocks80out_layers3bias", "model_ema.output_blocks80skip_connectionweight", "model_ema.output_blocks80skip_connectionbias", "model_ema.output_blocks81pre_projweight", "model_ema.output_blocks81pre_projbias", "model_ema.output_blocks81filterw1", "model_ema.output_blocks81filterb1", "model_ema.output_blocks81filterw2", "model_ema.output_blocks81filterb2", "model_ema.output_blocks81mlpfc1weight", 
"model_ema.output_blocks81mlpfc1bias", "model_ema.output_blocks81mlpfc2weight", "model_ema.output_blocks81mlpfc2bias", "model_ema.out2weight", "model_ema.out2bias". Unexpected key(s) in state_dict: "epoch", "global_step", "pytorch-lightning_version", "state_dict", "loops", "callbacks", "optimizer_states", "lr_schedulers".

The solution to the above is to load the state_dict stored inside the checkpoint, rather than the entire checkpoint file.

E.g.:

import torch

fn_ckpt = '../models/---.ckpt'
fn_ckpt_state_dict = '../models/state_dict_---.ckpt'

# Load the full Lightning checkpoint and save only its 'state_dict' entry.
ckpt = torch.load(fn_ckpt)
torch.save(ckpt['state_dict'], fn_ckpt_state_dict)
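The structure involved can be illustrated with a minimal, self-contained sketch (the nn.Linear model and the key values here are stand-ins, not the actual LDCast model): a Lightning .ckpt is a dict wrapping the weights under 'state_dict' together with training bookkeeping, which also explains the "Unexpected key(s)" in the error above and the roughly doubled file size.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in for the diffusion model; any nn.Module behaves the same.
model = nn.Linear(4, 2)

# A PyTorch Lightning .ckpt wraps the weights under the 'state_dict' key,
# alongside training bookkeeping (epoch, optimizer states, ...). The extra
# entries are why the .ckpt on disk is larger than a bare weights file.
ckpt = {
    "epoch": 0,
    "global_step": 1000,
    "state_dict": model.state_dict(),
    "optimizer_states": [],
}

with tempfile.TemporaryDirectory() as tmp:
    fn_full = os.path.join(tmp, "full.ckpt")
    fn_weights = os.path.join(tmp, "state_dict.ckpt")
    torch.save(ckpt, fn_full)

    # load_state_dict() expects the inner dict, not the Lightning wrapper:
    loaded = torch.load(fn_full, map_location="cpu")
    torch.save(loaded["state_dict"], fn_weights)

    # The extracted file now loads directly into the model.
    model.load_state_dict(torch.load(fn_weights, map_location="cpu"))
```

Passing the full wrapper dict to load_state_dict() instead would produce exactly the kind of "Missing key(s)" / "Unexpected key(s)" RuntimeError shown above.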

Then the following works:

from ldcast.forecast import Forecast

fn_aut = 'models/autoenc/autoenc-32-0.01.pt'
fn_ckpt_state_dict = '../models/state_dict_---.ckpt'
fc = Forecast(ldm_weights_fn=fn_ckpt_state_dict, autoenc_weights_fn=fn_aut)

Hi @tomasvanoyen, thanks for figuring it out!