CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)


Support logging for non-scalar metrics

g-simmons opened this issue · comments

🚀 The feature, motivation, and pitch

AccelerateRLTrainer.evaluate() logs a table of generated eval outputs and metrics to the metrics tracker.

If I understand correctly, only scalar metrics are currently supported.

This feature would allow non-scalar metrics to be logged.

Use cases:

Allow passthrough logging of non-scalar prompt metadata:

    import trlx

    # metric_fn that simply passes prompt metadata (kwargs) through,
    # so it ends up in the logged eval table
    def eval_metrics(samples, prompts, outputs, **kwargs):
        return kwargs

    trlx.train(
        model_name,
        config=config,
        samples=train_samples,  # type: ignore
        rewards=train_rewards if training_method == "ilql" else None,  # type: ignore
        eval_prompts=eval_eval_prompts,  # type: ignore
        metric_fn=eval_metrics,
    ).model
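
For contrast, a purely scalar metric_fn (the only kind the current averaging handles) returns one number per generated output, along these lines (a hypothetical illustration, not taken from the issue):

    def scalar_metrics(samples, prompts, outputs, **kwargs):
        # One float per generated output; evaluate() can average these
        return {"output_length": [float(len(o)) for o in outputs]}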

Implementation suggestion:

Modify the mean_metrics calculation (below) so that it only computes means for values that can successfully be cast to float tensors; a sketch of this follows the snippet.

mean_metrics = {
    f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item() for k, xs in metrics.items()
}
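
A minimal sketch of that change, assuming `metrics` and `sweep_suffix` come from the surrounding evaluate() scope; the `is_float_castable` helper is illustrative, not existing trlx code:

import torch

def is_float_castable(xs):
    # True only if the per-sample values form a valid float tensor
    try:
        torch.as_tensor(xs, dtype=torch.float)
        return True
    except Exception:
        return False

mean_metrics = {
    f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs, dtype=torch.float).mean(-1).item()
    for k, xs in metrics.items()
    if is_float_castable(xs)
}

Metrics that fail the check could then be routed to whatever per-sample or table logging the active tracker supports, instead of being averaged.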

Alternatives

No response

Additional context

No response

hey, looks good. Do you want me to assign you?

Don't have a ton of time to work on it currently :/

I just did a simple try/except to get around it for now; if someone wants to develop it further, that would be great.

mean_metrics = {}

for k, xs in metrics.items():
    try:
        mean_metrics[f"metrics/{k}{sweep_suffix}"] = torch.as_tensor(xs).mean(-1).item()
    except Exception:
        logger.warning(f"Metric {k} is not a scalar, skipping")
        continue
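
One way to develop it further, sketched here under the assumption that wandb is the active tracker (the `nonscalar_tables` dict and the single-column layout are my own choices, not trlx conventions):

import torch
import wandb

mean_metrics = {}
nonscalar_tables = {}  # hypothetical: per-sample values that cannot be averaged

for k, xs in metrics.items():
    try:
        # Scalar-like metrics: average as before (cast to float so int metrics also work)
        mean_metrics[f"metrics/{k}{sweep_suffix}"] = torch.as_tensor(xs, dtype=torch.float).mean(-1).item()
    except Exception:
        # Non-scalar metrics (strings, dicts, ...): keep the raw per-sample values
        # as a single-column table instead of dropping them; stringified so that
        # arbitrary metadata types are accepted by wandb.Table
        nonscalar_tables[f"metrics/{k}{sweep_suffix}"] = wandb.Table(
            columns=[k], data=[[str(x)] for x in xs]
        )

The tables could then be passed to the same tracker call as mean_metrics where the tracker accepts them (wandb.log handles Table objects); other trackers would presumably still need to skip or stringify these values.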