pytorch / torchtitan

A native PyTorch library for large model training


numerical issue when running SDPA with DTensor

tianyu-l opened this issue

The issue comes from the backward computation of aten.mul of two complex numbers from DTensors: the result is b + ai when it should be a + bi, i.e. the real and imaginary parts get swapped. Not sure why this happens -- by the time the aten ops run, the input tensors have already been de-sugared to plain local tensors and should have nothing to do with DTensor.
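For reference, the plain-tensor path gives the expected result. Under PyTorch's complex autograd convention, the gradient of out = x * w with respect to x is grad_out * w.conj(), so the correct gradient has a closed form. A minimal standalone sketch (variable names are mine, not from the test file):

    import torch

    xq = torch.randn(1, 1, 2, requires_grad=True)
    freqs_cis = torch.randn(1, 1, dtype=torch.complex64)

    out = torch.view_as_real(torch.view_as_complex(xq) * freqs_cis)
    out.sum().backward()

    # grad_out = 1 + 1j (the sum weights the real and imaginary parts equally),
    # so for freqs_cis = a + bi the gradient of xq_ is
    # (1 + 1j) * (a - bi) = (a + b) + (a - b)i.
    expected = torch.view_as_real((1 + 1j) * freqs_cis.conj())
    torch.testing.assert_close(xq.grad, expected)  # passes on plain tensors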

To reproduce, add the following test to pytorch/test/distributed/tensor/parallel/test_tp_examples.py:

    @with_comms
    def test_apply_rotary_embedding(self):
        device_mesh = self.build_device_mesh()

        def apply_rotary_emb(xq, freqs_cis):
            xq_ = torch.view_as_complex(xq)
            xq_out = torch.view_as_real(xq_ * freqs_cis)
            return xq_out

        with CommDebugMode():
            # Plain-tensor baseline (uncomment to compare against the DTensor path below):
            # xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
            # freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
            # xq_out = apply_rotary_emb(xq, freqs_cis)
            # xq_out.sum().backward()

            # DTensor path: replicate the inputs and run the same computation;
            # per this issue, the backward swaps real/imag in the resulting gradient.
            xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
            freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
            xq_dt = distribute_tensor(xq, device_mesh, (Replicate(),))
            freqs_cis_dt = distribute_tensor(freqs_cis, device_mesh, (Replicate(),))
            xq_out_dt = apply_rotary_emb(xq_dt, freqs_cis_dt)
            xq_out_dt.sum().backward()
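
The swap can then be checked against the closed form from the plain-tensor sketch above. This is only a sketch: the fallback between xq.grad and xq_dt.grad is mine, since whether distribute_tensor detaches (making xq_dt the leaf) may differ across versions:

            # Expected: grad_out * freqs_cis.conj() with grad_out = 1 + 1j,
            # i.e. components (a + b, a - b) for freqs_cis = a + bi.
            expected = torch.view_as_real((1 + 1j) * freqs_cis.conj())
            grad = xq_dt.grad.to_local() if xq_dt.grad is not None else xq.grad
            # Per this issue, the DTensor path returns the two components
            # swapped, so this assertion fails:
            torch.testing.assert_close(grad, expected)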