google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Home page: https://flax.readthedocs.io

Incompatible variables for Tensorboard hparams are recast to strings but never returned

tttc3 opened this issue

Core Problem

TensorBoard hparams only supports a subset of Python and NumPy value types (see the hparams docstrings). The flax.metrics.tensorboard.SummaryWriter.hparams() method is meant to handle this through the flax.metrics.tensorboard._flatten_dict() helper, which casts incompatible values to strings (a type hparams does support). However, although _flatten_dict performs the cast, it never appends the recast values to the dictionary it returns.

As a result, in the example below the "hidden_layers_list" and "hidden_layers_tuple" parameters are silently dropped and never appear in TensorBoard's hparams.

from flax.metrics import tensorboard

experiment_dir = "./Example"

network_hyperparameters = {
    "hidden_layers_list": [12, 12],
    "hidden_layers_tuple": (12, 12),
    "dropout_rate": 1.0,
}

summary_writer = tensorboard.SummaryWriter(experiment_dir)
summary_writer.hparams(network_hyperparameters)
summary_writer.scalar('Training loss', 0.1, 1)
summary_writer.flush()

Colab example: Example notebook

Proposed fix

Modify _flatten_dict to explicitly check whether each dictionary value has one of the types supported by TensorBoard's hparams API, as defined in TensorBoard's hparams summary_v2.py. If a value is not supported, cast it to a string and append it to the dictionary that _flatten_dict returns.
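The heart of this fix is a single type check. A minimal sketch of it, assuming the type list quoted from TensorBoard's summary_v2.py in the proposed code below; the names VALID_HPARAM_TYPES and to_hparam_value are hypothetical and are not Flax API:

```python
import numpy as np

# Hypothetical constant: the value types hparams accepts, mirroring the list
# proposed in this issue (taken from TensorBoard's hparams summary_v2.py).
VALID_HPARAM_TYPES = (
    bool, int, float, str,
    np.bool_, np.integer, np.floating, np.character,
)


def to_hparam_value(v):
    """Hypothetical helper: return v unchanged if hparams supports its type,
    otherwise return str(v) so the value is kept rather than dropped."""
    return v if isinstance(v, VALID_HPARAM_TYPES) else str(v)


# Lists, tuples and None fall back to their string form;
# Python and NumPy scalars pass through untouched.
print(to_hparam_value([12, 12]))
print(to_hparam_value(None))
print(to_hparam_value(1.0))
```

Because bool is a subclass of int, listing bool separately is redundant for isinstance, but it keeps the tuple aligned with the TensorBoard type list.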

Current _flatten_dict code

def _flatten_dict(input_dict, parent_key='', sep='.'):
  """Flattens and simplifies dict such that it can be used by hparams.

  Args:
    input_dict: Input dict, e.g., from ConfigDict.
    parent_key: String used in recursion.
    sep: String used to separate parent and child keys.

  Returns:
   Flattened dict.
  """
  items = []
  for k, v in input_dict.items():
    new_key = parent_key + sep + k if parent_key else k

    # Take special care of things hparams cannot handle.
    if v is None:
      v = 'None'
    elif isinstance(v, list):
      v = str(v)
    elif isinstance(v, tuple):
      v = str(v)
    elif isinstance(v, dict):
      # Recursively flatten the dict.
      items.extend(_flatten_dict(v, new_key, sep=sep).items())
    else:
      items.append((new_key, v))
  return dict(items)
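Running the current branching on the hyperparameters from the example above makes the drop visible. This is a condensed copy of the code above (list and tuple branches merged, hypothetical name current_flatten_dict) so that it runs without Flax installed:

```python
def current_flatten_dict(input_dict, parent_key='', sep='.'):
    # Hypothetical copy of the current _flatten_dict branching.
    items = []
    for k, v in input_dict.items():
        new_key = parent_key + sep + k if parent_key else k
        if v is None:
            v = 'None'  # recast, but never appended to items
        elif isinstance(v, (list, tuple)):
            v = str(v)  # recast, but never appended to items
        elif isinstance(v, dict):
            # Nested dicts are still flattened correctly.
            items.extend(current_flatten_dict(v, new_key, sep=sep).items())
        else:
            # Only values reaching this branch end up in the result.
            items.append((new_key, v))
    return dict(items)


network_hyperparameters = {
    "hidden_layers_list": [12, 12],
    "hidden_layers_tuple": (12, 12),
    "dropout_rate": 1.0,
}

print(current_flatten_dict(network_hyperparameters))
# {'dropout_rate': 1.0}
```

Only dropout_rate survives: both hidden_layers_* values are recast inside their own branches, but only the final else branch appends to items.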

Proposed _flatten_dict code modification

import numpy as np


def _flatten_dict(input_dict, parent_key='', sep='.'):
  """Flattens and simplifies dict such that it can be used by hparams.

  Args:
    input_dict: Input dict, e.g., from ConfigDict.
    parent_key: String used in recursion.
    sep: String used to separate parent and child keys.

  Returns:
    Flattened dict.
  """
  # Valid types according to https://github.com/tensorflow/tensorboard/blob/1204566da5437af55109f7a4af18f9f8b7c4f864/tensorboard/plugins/hparams/summary_v2.py
  # Defined once, outside the loop.
  valid_types = (bool, int, float, str, np.bool_, np.integer, np.floating, np.character)

  items = []
  for k, v in input_dict.items():
    new_key = parent_key + sep + k if parent_key else k

    if isinstance(v, dict):
      # Recursively flatten the dict.
      items.extend(_flatten_dict(v, new_key, sep=sep).items())
      continue
    elif not isinstance(v, valid_types):
      # Cast any incompatible value (None, list, tuple, ...) to a string
      # so that hparams can handle it.
      v = str(v)
    # Every non-dict value, recast or not, is appended.
    items.append((new_key, v))
  return dict(items)

I am happy to submit a pull request with these modifications.

Thanks for noticing this. Indeed, there seems to be a bug in our code: we actually do nothing with v if it is None, a list or a tuple! Yes, it would be great if you could file this as a PR, and I think your suggested change using valid_types is an improvement.

We should also run internal tests on this to make sure your change doesn't break anything.