Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.

Home Page: https://lightning.ai

Avoid casting with `numpy()` in `multiprocessing.py`

Peiffap opened this issue.

Outline & Motivation

Currently, get_extra_results() casts callback metrics to numpy to avoid problems with memory sharing:

```python
callback_metrics: dict = apply_to_collection(
    trainer.callback_metrics, Tensor, lambda x: x.cpu().numpy()
)  # send as numpy to avoid issues with memory sharing
return {"callback_metrics": callback_metrics}
```

Then update_main_process_results() casts them back to tensors:
```python
# NOTE: `get_extra_results` needs to be called before
callback_metrics = extra["callback_metrics"]
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
```
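To make the current round trip concrete, here is a minimal standalone sketch of what the two snippets do together. It assumes a hypothetical helper `apply_to_type` as a simplified stand-in for `apply_to_collection` (it only walks dicts, lists and tuples), and the metric names are made up. It shows that the NumPy detour preserves both values and dtypes, which is the behaviour any replacement needs to keep.

```python
import numpy as np
import torch


def apply_to_type(data, dtype, fn):
    # Hypothetical, simplified stand-in for `apply_to_collection`: apply `fn` to
    # every leaf of type `dtype` inside nested dicts, lists and tuples.
    if isinstance(data, dtype):
        return fn(data)
    if isinstance(data, dict):
        return {k: apply_to_type(v, dtype, fn) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_type(v, dtype, fn) for v in data)
    return data


# Worker side: tensors -> NumPy arrays (mirrors `get_extra_results`).
callback_metrics = {
    "val_loss": torch.tensor(0.25),
    "acc": torch.tensor([0.5, 0.75], dtype=torch.float64),
}
sent = apply_to_type(callback_metrics, torch.Tensor, lambda x: x.cpu().numpy())

# Main-process side: NumPy arrays -> tensors (mirrors `update_main_process_results`).
received = apply_to_type(sent, np.ndarray, lambda x: torch.tensor(x))
```

Because `torch.tensor` on an `ndarray` keeps the array's dtype, the `float64` metric comes back as `float64` here.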

It would be neater (and in line with the broader goal of not depending on the numpy package, see #16649) to avoid this trick.

Pitch

Remove the cast to numpy without introducing errors, and remove the numpy dependency from multiprocessing.py.

Additional context

As discussed in #19841 (ref.).

cc @justusschock @awaelchli

I looked at the code of get_extra_results(). It only converts the tensors to NumPy arrays (moving them to the CPU first). If we remove the NumPy usage from it, the function reduces to the following body:

```python
def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
    """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
    avoid issues with memory sharing, we cast the data to numpy.

    Args:
        trainer: reference to the Trainer.

    Returns:
        A dictionary with items to send back to the main process where :meth:`update_main_process_results` will
        process this output.

    """
    return {"callback_metrics": trainer.callback_metrics}
```

I think we could use plain Python types such as list and dict instead of NumPy.
For that, the natural function to work with would be apply_to_collection.
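A minimal sketch of that idea, assuming the worker sends nested Python lists/scalars produced by `Tensor.tolist()` instead of NumPy arrays. `PlainTensor`, `apply_to_type`, `to_plain` and `from_plain` are hypothetical names invented here (`apply_to_type` is a simplified stand-in for `apply_to_collection`), not Lightning APIs, and whether plain Python objects fully sidestep the memory-sharing problem in the launcher is an assumption that would need testing. One detail the NumPy version handled implicitly is the dtype: `torch.tensor()` on Python floats returns the default dtype (normally `float32`), so the sketch records the dtype name next to the values.

```python
from dataclasses import dataclass
from typing import Any, Callable

import torch


@dataclass(frozen=True)
class PlainTensor:
    # Hypothetical container: tensor data as nested Python lists/scalars plus the
    # dtype name, so only plain Python objects cross the process boundary.
    values: Any
    dtype: str  # e.g. "float32", "float64", "int64"


def apply_to_type(data: Any, dtype: type, fn: Callable[[Any], Any]) -> Any:
    # Simplified stand-in for `apply_to_collection`: apply `fn` to every leaf of
    # type `dtype` inside nested dicts, lists and tuples.
    if isinstance(data, dtype):
        return fn(data)
    if isinstance(data, dict):
        return {k: apply_to_type(v, dtype, fn) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_type(v, dtype, fn) for v in data)
    return data


def to_plain(t: torch.Tensor) -> PlainTensor:
    # `.tolist()` copies the data into Python floats/ints/bools (a bare scalar for 0-d tensors).
    return PlainTensor(t.detach().cpu().tolist(), str(t.dtype).split(".")[-1])


def from_plain(p: PlainTensor) -> torch.Tensor:
    # Pass the dtype explicitly: otherwise Python floats become torch.get_default_dtype().
    return torch.tensor(p.values, dtype=getattr(torch, p.dtype))


# Worker side (what `get_extra_results` could send instead of NumPy arrays).
metrics = {
    "val_loss": torch.tensor(0.25),
    "acc": torch.tensor([0.5, 0.75], dtype=torch.float64),
    "step": torch.tensor(3),
}
sent = apply_to_type(metrics, torch.Tensor, to_plain)

# Main-process side (what `update_main_process_results` could do instead of converting ndarrays).
received = apply_to_type(sent, PlainTensor, from_plain)
```

The trade-off is an extra Python-level copy per metric, which should be negligible for scalar metrics but could matter if large tensors ever end up in `callback_metrics`.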

Please share your thoughts.
Thanks