Multiple global conditioning values with a local dataset.
fred-dev opened this issue · comments
I am having issues with creating a config for using gloabl conditioning values. I cannot peice this together from the examples provided and am not sure where the issue comes from, I suspect it is in my training config.
I have audio recordings tabulated with environmental data, and want to use 8 floating point parameters for global conditioning using the 'adp_cfg_1d' conditional diffusion model.
I have prepared metadata in a single JSON file, but adjusted the custom metadata function to return a dictionary of all 8 parameters.
It would be great if anyone can shed some light on this.
Here is my config - I am using a pretrained autoencoder. I have added the conditioning details:
modelConfig = {
"model_type": "diffusion_cond",
"sample_size": 524288,
"sample_rate": 44100,
"audio_channels": 1,
"model": {
"io_channels": 64,
"diffusion": {
"type": "adp_cfg_1d",
"supports_global_cond": True,
"global_cond_ids": [
"latitude",
"longitude",
"temperature",
"humidity",
"wind_speed",
"pressure",
"minutes_of_day",
"day_of_year"
],
"config": {
"channels": 64,
"context_embedding_max_length": 8,
"context_embedding_features": 8,
"in_channels": 64,
"multipliers": [
1,
2,
4,
8,
16
],
"factors": [
1,
2,
2,
2
],
"num_blocks": [
1,
2,
2,
2
],
"attentions": [
0,
1,
1,
1,
1
],
"attention_heads": 8,
"attention_multiplier": 2
}
},
"pretransform": {
"type": "dac_pretrained",
"config": {}
},
"conditioning": {
"configs": [
{
"id": "latitude",
"type": "number",
"config": {
"min_val": 0.19656993333333334,
"max_val": 50.443667
}
},
{
"id": "longitude",
"type": "number",
"config": {
"min_val": 0.7689536111111112,
"max_val": 0.9665608333333334
}
},
{
"id": "temperature",
"type": "number",
"config": {
"min_val": -3.24085585279757,
"max_val": 3.6700142339763984
}
},
{
"id": "humidity",
"type": "number",
"config": {
"min_val": -3.8110683227566686,
"max_val": 1.4696676807777593
}
},
{
"id": "wind_speed",
"type": "number",
"config": {
"min_val": -1.6188376276678578,
"max_val": 11.608313462165052
}
},
{
"id": "pressure",
"type": "number",
"config": {
"min_val": -6.263994121817382,
"max_val": 5.390325035103388
}
},
{
"id": "minutes_of_day",
"type": "number",
"config": {
"min_val": 0,
"max_val": 0.9993055555555556
}
},
{
"id": "day_of_year",
"type": "number",
"config": {
"min_val": 0,
"max_val": 1.0027472527472527
}
}
],
"cond_dim": 8
}
},
"training": {
"learning_rate": 0.00004,
"demo": {
"demo_every": 1500,
"demo_steps": 100,
"num_demos": 3,
"demo_cfg_scales": [
1,
1,
1,
1,
1,
1,
1,
1
],
"demo_cond": [
{
"latitude": 0.337113507707468,
"longitude": 0.8998659802839629,
"temperature": 2.0902047266579145e-16,
"humidity": -5.249300186068055e-17,
"wind_speed": -1.7127625984753777e-16,
"pressure": -3.1158006270597913e-15,
"minutes_of_day": 0.3977426990645676,
"day_of_year": 0.5956888170306616
},
{
"latitude": 0.38461336306221783,
"longitude": 0.9281043879416349,
"temperature": 1.0000000000000002,
"humidity": 1,
"wind_speed": 0.9999999999999998,
"pressure": 0.9999999999999969,
"minutes_of_day": 0.6036759423548516,
"day_of_year": 0.8767792548963371
},
{
"latitude": 0.2896136523527182,
"longitude": 0.871627572626291,
"temperature": -0.9999999999999998,
"humidity": -1,
"wind_speed": -1.0000000000000002,
"pressure": -1.000000000000003,
"minutes_of_day": 0.19180945577428368,
"day_of_year": 0.31459837916498595
}
]
}
}
}
Here is my custom metadata file:
import json
# Load the metadata from the JSON file into a list for quick access.
metadata_file_path = 'path/to/all_metadata.json'
with open(metadata_file_path, 'r') as file:
audio_metadata_list = json.load(file)
def get_custom_metadata(info, audio):
# Extract the filename from the `info` parameter.
audio_filename = info.get("relpath", "").split('/')[-1].replace('_P.wav', '')
#find the entry in the JSON file.
metadata_entry = next((item for item in audio_metadata_list if item["filename"] == audio_filename), None)
# Default values for all keys.
metadata_for_entry = {
"latitude": 0.0,
"longitude": 0.0,
"temperature": 0.0,
"humidity": 0.0,
"wind_speed": 0.0,
"pressure": 0.0,
"minutes_of_day": 0.0,
"day_of_year": 0.0,
}
if metadata_entry:
for key in metadata_for_entry.keys():
if key in metadata_entry:
metadata_for_entry[key] = metadata_entry[key]
return metadata_for_entry
I get this error, the feedback and examples do not cover the required model config for conditional training like this, so it is difficult to debug.
File "./train.py", line 110, in main
trainer.fit(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 545, in fit
call._call_and_handle_interrupt(
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 581, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
results = self._run_stage()
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1036, in _run_stage
self.fit_loop.run()
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
self.advance()
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 359, in advance
self.epoch_loop.run(self._data_fetcher)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 136, in run
self.advance(data_fetcher)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 240, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 187, in run
self._optimizer_step(batch_idx, closure)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 265, in _optimizer_step
call._call_lightning_module_hook(
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/core/module.py", line 1282, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 151, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 230, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/amp.py", line 77, in optimizer_step
closure_result = closure()
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 140, in __call__
self._result = self.closure(*args, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 126, in closure
step_output = self._step_fn()
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 315, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 382, in training_step
return self.lightning_module.training_step(*args, **kwargs)
File "/content/stable-audio-tools/stable_audio_tools/training/diffusion.py", line 406, in training_step
v = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0.1, **extra_args)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
result = forward_call(*args, **kwargs)
File "/content/stable-audio-tools/stable_audio_tools/models/diffusion.py", line 180, in forward
return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/content/stable-audio-tools/stable_audio_tools/models/diffusion.py", line 225, in forward
outputs = self.model(
File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/content/stable-audio-tools/stable_audio_tools/models/adp.py", line 1280, in forward
b, device = embedding.shape[0], embedding.device
AttributeError: 'NoneType' object has no attribute 'shape'
On further investigation it seems the conditioning data is not forwarded through the UNet.
It flows to UNetCFG1d but the embedding is empty in the forward pass in adp.py line 1280.