Stability-AI / stable-audio-tools

Generative models for conditional audio generation

Multiple global conditioning values with a local dataset.

fred-dev opened this issue

I am having issues creating a config that uses global conditioning values. I cannot piece this together from the provided examples, and I am not sure where the problem comes from; I suspect it is in my training config.

I have audio recordings tabulated with environmental data and want to use 8 floating-point parameters for global conditioning with the 'adp_cfg_1d' conditional diffusion model.

I have prepared the metadata in a single JSON file and adjusted the custom metadata function to return a dictionary of all 8 parameters.
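
For reference, all_metadata.json is a list of entries keyed by filename. A single entry looks like this (the filename is a placeholder; the parameter values are copied from the first demo_cond entry in the config below, just to show the shape of the data):

{
  "filename": "REC001",
  "latitude": 0.337113507707468,
  "longitude": 0.8998659802839629,
  "temperature": 2.0902047266579145e-16,
  "humidity": -5.249300186068055e-17,
  "wind_speed": -1.7127625984753777e-16,
  "pressure": -3.1158006270597913e-15,
  "minutes_of_day": 0.3977426990645676,
  "day_of_year": 0.5956888170306616
}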

It would be great if anyone could shed some light on this.

Here is my model config. I am using a pretrained autoencoder and have added the conditioning details:

modelConfig = {
  "model_type": "diffusion_cond",
  "sample_size": 524288,
  "sample_rate": 44100,
  "audio_channels": 1,
  "model": {
    "io_channels": 64,
    "diffusion": {
      "type": "adp_cfg_1d",
      "supports_global_cond": True,
      "global_cond_ids": [
        "latitude",
        "longitude",
        "temperature",
        "humidity",
        "wind_speed",
        "pressure",
        "minutes_of_day",
        "day_of_year"
      ],
    
      "config": {
        "channels": 64,
        "context_embedding_max_length": 8,
        "context_embedding_features": 8,
        "in_channels": 64,
        "multipliers": [
          1,
          2,
          4,
          8,
          16
        ],
        "factors": [
          1,
          2,
          2,
          2
        ],
        "num_blocks": [
          1,
          2,
          2,
          2
        ],
        "attentions": [
          0,
          1,
          1,
          1,
          1
        ],
        "attention_heads": 8,
        "attention_multiplier": 2
      }
    },
    "pretransform": {
      "type": "dac_pretrained",
      "config": {}
    },
    "conditioning": {
      "configs": [
        {
          "id": "latitude",
          "type": "number",
          "config": {
            "min_val": 0.19656993333333334,
            "max_val": 50.443667
          }
        },
        {
          "id": "longitude",
          "type": "number",
          "config": {
            "min_val": 0.7689536111111112,
            "max_val": 0.9665608333333334
          }
        },
        {
          "id": "temperature",
          "type": "number",
          "config": {
            "min_val": -3.24085585279757,
            "max_val": 3.6700142339763984
          }
        },
        {
          "id": "humidity",
          "type": "number",
          "config": {
            "min_val": -3.8110683227566686,
            "max_val": 1.4696676807777593
          }
        },
        {
          "id": "wind_speed",
          "type": "number",
          "config": {
            "min_val": -1.6188376276678578,
            "max_val": 11.608313462165052
          }
        },
        {
          "id": "pressure",
          "type": "number",
          "config": {
            "min_val": -6.263994121817382,
            "max_val": 5.390325035103388
          }
        },
        {
          "id": "minutes_of_day",
          "type": "number",
          "config": {
            "min_val": 0,
            "max_val": 0.9993055555555556
          }
        },
        {
          "id": "day_of_year",
          "type": "number",
          "config": {
            "min_val": 0,
            "max_val": 1.0027472527472527
          }
        }
      ],
      "cond_dim": 8

    }
  },
  "training": {
    "learning_rate": 0.00004,
    "demo": {
      "demo_every": 1500,
      "demo_steps": 100,
      "num_demos": 3,
      "demo_cfg_scales": [
        1,
        1,
        1,
        1,
        1,
        1,  
        1,
        1
      ],
      "demo_cond": [
        {
          "latitude": 0.337113507707468,
          "longitude": 0.8998659802839629,
          "temperature": 2.0902047266579145e-16,
          "humidity": -5.249300186068055e-17,
          "wind_speed": -1.7127625984753777e-16,
          "pressure": -3.1158006270597913e-15,
          "minutes_of_day": 0.3977426990645676,
          "day_of_year": 0.5956888170306616
        },
        {
          "latitude": 0.38461336306221783,
          "longitude": 0.9281043879416349,
          "temperature": 1.0000000000000002,
          "humidity": 1,
          "wind_speed": 0.9999999999999998,
          "pressure": 0.9999999999999969,
          "minutes_of_day": 0.6036759423548516,
          "day_of_year": 0.8767792548963371
        },
        {
          "latitude": 0.2896136523527182,
          "longitude": 0.871627572626291,
          "temperature": -0.9999999999999998,
          "humidity": -1,
          "wind_speed": -1.0000000000000002,
          "pressure": -1.000000000000003,
          "minutes_of_day": 0.19180945577428368,
          "day_of_year": 0.31459837916498595
        }
      ]
    }
  }
}
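
For completeness, my dataset config and launch command follow the local-training setup from the README (all paths and names here are placeholders):

{
  "dataset_type": "audio_dir",
  "datasets": [
    {
      "id": "field_recordings",
      "path": "/path/to/audio"
    }
  ],
  "custom_metadata_module": "/path/to/custom_metadata.py",
  "random_crop": true
}

python3 ./train.py --dataset-config /path/to/dataset_config.json --model-config /path/to/model_config.json --name global_cond_test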

Here is my custom metadata file:

import json

# Load the metadata from the JSON file into a list for quick access.
metadata_file_path = 'path/to/all_metadata.json'
with open(metadata_file_path, 'r') as file:
    audio_metadata_list = json.load(file) 

def get_custom_metadata(info, audio):
    # Derive the lookup key from the file's relative path, stripping the '_P.wav' suffix.
    audio_filename = info.get("relpath", "").split('/')[-1].replace('_P.wav', '')
    # Find the matching entry in the metadata list.
    metadata_entry = next((item for item in audio_metadata_list if item["filename"] == audio_filename), None)

    # Default values for all keys.
    metadata_for_entry = {
        "latitude": 0.0,
        "longitude": 0.0,
        "temperature": 0.0,
        "humidity": 0.0,
        "wind_speed": 0.0,
        "pressure": 0.0,
        "minutes_of_day": 0.0,
        "day_of_year": 0.0,
    }

    if metadata_entry:
        for key in metadata_for_entry.keys():
            if key in metadata_entry:
                metadata_for_entry[key] = metadata_entry[key]

    return metadata_for_entry
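
To rule out the lookup itself, the function can be called directly with a fake info dict (the path below is made up):

# Quick sanity check of the metadata lookup; "recordings/REC001_P.wav" is a placeholder path.
test_info = {"relpath": "recordings/REC001_P.wav"}
print(get_custom_metadata(test_info, None))
# Expected: a dict with the eight conditioning keys, holding either the values from
# all_metadata.json for "REC001" or the 0.0 defaults if no entry matches.

So the lookup can be checked independently of the training run.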

I get the error below. The error output and the provided examples do not cover the model config required for conditional training like this, so it is difficult to debug.

File "./train.py", line 110, in main
    trainer.fit(training_wrapper, train_dl, ckpt_path=args.ckpt_path if args.ckpt_path else None)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 545, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 581, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
    results = self._run_stage()
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1036, in _run_stage
    self.fit_loop.run()
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
    self.advance()
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 359, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 136, in run
    self.advance(data_fetcher)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 240, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 187, in run
    self._optimizer_step(batch_idx, closure)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 265, in _optimizer_step
    call._call_lightning_module_hook(
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/core/module.py", line 1282, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 151, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 230, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/amp.py", line 77, in optimizer_step
    closure_result = closure()
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 140, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 126, in closure
    step_output = self._step_fn()
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/automatic.py", line 315, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 382, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/content/stable-audio-tools/stable_audio_tools/training/diffusion.py", line 406, in training_step
    v = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0.1, **extra_args)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/content/stable-audio-tools/stable_audio_tools/models/diffusion.py", line 180, in forward
    return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/stable-audio-tools/stable_audio_tools/models/diffusion.py", line 225, in forward
    outputs = self.model(
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/envs/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/stable-audio-tools/stable_audio_tools/models/adp.py", line 1280, in forward
    b, device = embedding.shape[0], embedding.device
AttributeError: 'NoneType' object has no attribute 'shape'

On further investigation, it seems the conditioning data is not being forwarded into the UNet: the call reaches UNetCFG1d, but the embedding argument is None in the forward pass at adp.py line 1280.
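
A minimal way to confirm this is a temporary print just above the call shown in the traceback, in stable_audio_tools/models/diffusion.py around line 180 (a debug sketch only, not a fix):

# Debug sketch: inspect what the conditioning wrapper hands to the inner model.
# Inserted in place of the existing call at stable_audio_tools/models/diffusion.py:180.
cond_inputs = self.get_conditioning_inputs(cond)
print({k: (tuple(v.shape) if hasattr(v, "shape") else v) for k, v in cond_inputs.items()})
return self.model(x, t, **cond_inputs, **kwargs)

Whatever comes back as None in that dict is what UNetCFG1d then fails on at adp.py line 1280.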