huggingface / evaluate

🤗 Evaluate: A library for easily evaluating machine learning models and datasets.

Home Page: https://huggingface.co/docs/evaluate


Understanding metric compute in DDP setup

Natooz opened this issue · comments

Hello,

I am having a bit of trouble handling metrics in DDP setup.
I read the docs, especially the part about DDP, but couldn't find many details on how to handle this. I find it a bit confusing: what does "computing the final metric on the first node" mean?

With two processes, when computing two metrics, here is what .compute returns (the first number is the process id):

1 None
1 None
0 {'accuracy': 0.009848218392359514}
0 {'f1': 0.49228395061728397}

So what's happening? If I understand correctly, all processes (on each node? for all nodes combined? the doc says "first node", meaning its first process, right?) other than the first write their predictions and references to the cache. The first process waits for the predictions and references from all the others, and once it has gathered them all, computes and returns the metric results. Am I correct?
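For what it's worth, the cache-and-gather behavior described above can be sketched roughly like this (a toy illustration of the pattern, not the actual evaluate internals; `ToyMetric` and the in-memory `shared_cache` are stand-ins for the library's on-disk cache and file locks):

```python
class ToyMetric:
    """Toy sketch of distributed metric computation: every process caches
    its shard; only process 0 gathers all shards and computes the score."""

    def __init__(self, process_id, num_process, shared_cache):
        self.process_id = process_id
        self.num_process = num_process
        self.shared_cache = shared_cache  # stands in for the on-disk cache

    def compute(self, predictions, references):
        # Every process writes its shard to the cache first.
        self.shared_cache[self.process_id] = (predictions, references)
        if self.process_id != 0:
            # Non-main processes only contribute data and return None.
            return None
        # Process 0 gathers all shards (here the cache is already full),
        # concatenates them, and computes the metric on the full set.
        all_preds, all_refs = [], []
        for pid in range(self.num_process):
            preds, refs = self.shared_cache[pid]
            all_preds.extend(preds)
            all_refs.extend(refs)
        correct = sum(p == r for p, r in zip(all_preds, all_refs))
        return {"accuracy": correct / len(all_preds)}


# Two "processes" sharing one cache. Process 1 contributes first here so
# its shard is cached before process 0 aggregates (the real library
# synchronizes this via file locks instead).
cache = {}
m0 = ToyMetric(process_id=0, num_process=2, shared_cache=cache)
m1 = ToyMetric(process_id=1, num_process=2, shared_cache=cache)

assert m1.compute([1, 0], [1, 1]) is None       # non-main: returns None
result = m0.compute([1, 1], [1, 0])             # main: full-set score
```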

If so:

  1. How do we make sure that the other processes provide their preds and refs? Does .compute do it, or .add?
  2. When using it with the Trainer, is it OK to return 0s for processes (and nodes?) other than the first?

Here is what my compute_metrics function looks like; is it OK?

from typing import Dict

from evaluate import Metric


def compute_metrics_pt(eval_pred, metrics_: Dict[str, Metric]):
    (predictions_mlm, predictions_nsp), (labels_mlm, labels_nsp) = eval_pred

    # preprocess tensors: drop padded positions (label -100)
    pad_mask = labels_mlm != -100
    labels_mlm, predictions_mlm = labels_mlm[pad_mask], predictions_mlm[pad_mask]

    # compute metrics (these return None on non-main processes in DDP)
    acc = metrics_["accuracy"].compute(predictions=predictions_mlm, references=labels_mlm)
    f1 = metrics_["f1"].compute(predictions=predictions_nsp.flatten(), references=labels_nsp.flatten(), average="micro")

    # first process
    if metrics_["accuracy"].process_id == 0:  # SHOULD CHECK THAT IT'S THE FIRST NODE TOO?
        metric_res = {"accuracy_mlm": acc["accuracy"], "f1_nsp": f1["f1"]}

    # other processes
    else:
        metric_res = {"accuracy_mlm": 0, "f1_nsp": 0}
    return metric_res
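Regarding the placeholder values for the other processes: since .compute() returns the scores dict only on the main process and None everywhere else, one alternative to branching on process_id is a small guard like the one below (a sketch; `safe_score` is a hypothetical helper, not part of evaluate):

```python
def safe_score(result, key, default=0.0):
    # In a distributed setup, Metric.compute() returns the scores dict
    # only on the main process and None on every other process; fall
    # back to a placeholder value there.
    return default if result is None else result[key]
```

With it, the return statement above could become `{"accuracy_mlm": safe_score(acc, "accuracy"), "f1_nsp": safe_score(f1, "f1")}` on every process.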

Found the answers in the old documentation.