huggingface / nanotron

Minimalistic large language model 3D-parallelism training

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

[Bug] Fix the gradient-clipping test

xrsrke opened this issue · comments

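The failing CI/CD run: https://github.com/huggingface/nanotron/actions/runs/8116090525/job/22185337029 — the computed and reference total norms differ by 2^-15 (≈3.05e-05, about 2 float32 ULPs at this magnitude). This exceeds the test's tolerances (atol=1e-6, rtol=1e-7); note that rtol=1e-7 is tighter than float32 machine epsilon (≈1.19e-7), so even a 1-ULP difference fails.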

FAILED tests/test_clip_grads.py::test_clip_grads_with_pp[1.0]

tests/helpers/utils.py:221: in _run_until_success
    ret = func(*args, **kwargs)
tests/test_clip_grads.py:37: in test_clip_grads_with_pp
    init_distributed(tp=1, dp=1, pp=2)(_test_clip_grads_with_pp)(norm_type=norm_type)
tests/helpers/utils.py:272: in wrapper
    mp.spawn(global_wrapper, args=args, nprocs=world_size)
/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/spawn.py:246: in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/spawn.py:202: in start_processes
    while not context.join():
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch.multiprocessing.spawn.ProcessContext object at 0x7f2f74e1c730>
timeout = None

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.
    
        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.
    
        Args:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True
    
        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )
    
        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break
    
        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0
    
        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()
    
        # There won't be an error on the queue if the process crashed.
        failed_process = self.processes[error_index]
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                name = signal.Signals(-exitcode).name
                raise ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            else:
                raise ProcessExitedException(
                    "process %d terminated with exit code %d" % (error_index, exitcode),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                )
    
        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
>       raise ProcessRaisedException(msg, error_index, failed_process.pid)
E       torch.multiprocessing.spawn.ProcessRaisedException: 
E       
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
E           fn(i, *args)
E         File "/__w/nanotron/nanotron/tests/helpers/utils.py", line 259, in global_wrapper
E           func(parallel_context, **kwargs)
E         File "/__w/nanotron/nanotron/tests/test_clip_grads.py", line 151, in _test_clip_grads_with_pp
E           torch.testing.assert_close(total_norm, reference_total_norm, atol=1e-6, rtol=1e-7)
E         File "/usr/local/lib/python3.10/dist-packages/torch/testing/_comparison.py", line 1520, in assert_close
E           raise error_metas[0].to_error(msg)
E       AssertionError: Scalars are not close!
E       
E       Expected 193.16995239257812 but got 193.169921875.
E       Absolute difference: 3.0517578125e-05 (up to 1e-06 allowed)
E       Relative difference: 1.5798304936670125e-07 (up to 1e-07 allowed)

/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/spawn.py:163: ProcessRaisedException