Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

Logging Hyperparameters for list of dicts

vork opened this issue

Bug description

Currently, when hyperparameters are logged with log_hyperparams, the function calls _flatten_dict to collapse the dict to a single level. However, when the config contains a list of dicts, the whole list ends up logged as a single string. Instead, I propose flattening each list element under its index, producing keys such as key/0/item and key/1/item.
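To illustrate the behaviour described above, here is a minimal, self-contained sketch. `flatten_current` is a hypothetical, simplified stand-in for the existing flattening (it only expands nested mappings), and the `logged` step is an assumed approximation of a logger converting non-primitive values to strings; neither is the library's actual code.

```python
from collections.abc import MutableMapping
from typing import Any, Dict


def flatten_current(params: MutableMapping, delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]:
    """Simplified stand-in: only nested mappings are expanded."""
    result: Dict[str, Any] = {}
    for k, v in params.items():
        new_key = parent_key + delimiter + str(k) if parent_key else str(k)
        if isinstance(v, MutableMapping):
            result.update(flatten_current(v, delimiter=delimiter, parent_key=new_key))
        else:
            # Lists, including lists of dicts, are stored whole under one key.
            result[new_key] = v
    return result


config = {"optimizer": {"lr": 0.1}, "layers": [{"units": 64}, {"units": 32}]}
flat = flatten_current(config)

# Assumed approximation of a logger stringifying non-primitive values.
logged = {k: v if isinstance(v, (bool, int, float, str)) else str(v) for k, v in flat.items()}
print(logged)
# {'optimizer/lr': 0.1, 'layers': "[{'units': 64}, {'units': 32}]"}
```

The nested `optimizer` dict gets its own key per entry, while `layers` collapses into one opaque string, which makes individual entries hard to compare across runs.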

A fix could be simple:

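from argparse import Namespace
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, MutableMapping


def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]:
    """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.

    Args:
        params: Dictionary containing the hyperparameters
        delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
        parent_key: Key prefix accumulated during recursion. Defaults to ``''``.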

    Returns:
        Flattened dict.

    Examples:
        >>> _flatten_dict({'a': {'b': 'c'}})
        {'a/b': 'c'}
        >>> _flatten_dict({'a': {'b': 123}})
        {'a/b': 123}
        >>> _flatten_dict({5: {'a': 123}})
        {'5/a': 123}
        >>> _flatten_dict({"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]})
        {'dl/0/a': 1, 'dl/0/c': 3, 'dl/1/b': 2, 'dl/1/d': 5, 'l': [1, 2, 3, 4]}

    """
    result: Dict[str, Any] = {}
    for k, v in params.items():
        new_key = parent_key + delimiter + str(k) if parent_key else str(k)
        if is_dataclass(v):
            v = asdict(v)
        elif isinstance(v, Namespace):
            v = vars(v)

        if isinstance(v, MutableMapping):
            result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)}
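        # Also handle the case where v is a non-empty list of dictionaries.
        # An empty list falls through to the else branch so its key is not dropped.
        elif isinstance(v, list) and v and all(isinstance(item, MutableMapping) for item in v):
            for i, item in enumerate(v):
                # Join the list index with ``delimiter`` rather than a hard-coded "/"
                result = {**result, **_flatten_dict(item, parent_key=f"{new_key}{delimiter}{i}", delimiter=delimiter)}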
        else:
            result[new_key] = v
    return result
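
The proposal can be exercised in isolation with a renamed copy. In this sketch, `flatten_with_lists` is a hypothetical name, the list index is joined with `delimiter`, and empty or mixed lists are kept whole under their key; these are choices made for the sketch, not confirmed library behaviour.

```python
from argparse import Namespace
from collections.abc import MutableMapping
from dataclasses import asdict, is_dataclass
from typing import Any, Dict


def flatten_with_lists(params: MutableMapping, delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]:
    """Flatten nested mappings and lists of mappings into a single level."""
    result: Dict[str, Any] = {}
    for k, v in params.items():
        new_key = parent_key + delimiter + str(k) if parent_key else str(k)
        # Convert dataclass instances and Namespaces to plain dicts first.
        if is_dataclass(v) and not isinstance(v, type):
            v = asdict(v)
        elif isinstance(v, Namespace):
            v = vars(v)

        if isinstance(v, MutableMapping):
            result.update(flatten_with_lists(v, delimiter=delimiter, parent_key=new_key))
        elif isinstance(v, list) and v and all(isinstance(item, MutableMapping) for item in v):
            # Non-empty list of dicts: one key prefix per list index.
            for i, item in enumerate(v):
                result.update(
                    flatten_with_lists(item, delimiter=delimiter, parent_key=f"{new_key}{delimiter}{i}")
                )
        else:
            # Scalars, empty lists, and mixed lists are stored unchanged.
            result[new_key] = v
    return result


params = {"dl": [{"a": 1, "c": 3}, {"b": 2}], "l": [1, 2], "ns": Namespace(x=1)}
print(flatten_with_lists(params, delimiter="."))
# {'dl.0.a': 1, 'dl.0.c': 3, 'dl.1.b': 2, 'l': [1, 2], 'ns.x': 1}
print(flatten_with_lists({"empty": [], "mixed": [{"a": 1}, 2]}))
# {'empty': [], 'mixed': [{'a': 1}, 2]}
```

Requiring every element to be a mapping keeps lists of plain values such as `[1, 2]` untouched, so existing logs for those keys would not change shape.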

What version are you seeing the problem on?

master
