pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)

Home Page: https://pytorch.org/xla

[Functionalization] Dynamo tests

wonjoolee95 opened this issue · comments

Issue for tracking the two dynamo tests that fail with functionalization. The following dynamo tests (https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo.py) are now failing:

  1. test_simple_model
  2. test_resnet18

The test test_simple_model fails with:

(base) jenkins@26d7adccbc26:/workspace/pytorch/xla$ TORCH_SHOW_DISPATCH_TRACE=1 python test/dynamo/test_dynamo.py -k test_simple_model
======================================================================
FAIL: test_simple_model (__main__.DynamoBasicTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test/dynamo/test_dynamo.py", line 42, in test_simple_model
    self.assertNotIn('xla::add', met.counter_names())
AssertionError: 'xla::add' unexpectedly found in ['CreateXlaTensor', 'DestroyLtcTensor', 'DestroyXlaTensor', 'xla::add', 'xla::cos', 'xla::sin']

----------------------------------------------------------------------
Ran 1 test in 0.207s

FAILED (failures=1)

The first assertion self.assertIn('xla::add', met.counter_names()) at https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo.py#L37 succeeds, but the second assertion self.assertNotIn('xla::add', met.counter_names()) at https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo.py#L42 fails with the error message above.
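
For context, here is a minimal sketch (not the actual test) of the pattern these two assertions imply: an eager run that records lowered-op counters such as xla::add, followed by a dynamo-compiled run that is expected to replay the cached XLA computation without re-tracing. fn_simple mirrors the function named in the IR dumps below; the backend string and tensor values are assumptions and may differ from the real test.

import torch
import torch._dynamo
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def fn_simple(a, b):
    # mirrors the ops visible in the IR dumps below: cos, sin, add
    return a.cos() + b.sin()

device = xm.xla_device()
a = torch.tensor(100.0, device=device)
b = torch.tensor(200.0, device=device)

# Eager run: lowering the ops records counters like 'xla::add'.
fn_simple(a, b)
xm.mark_step()
assert 'xla::add' in met.counter_names()

met.clear_counters()

# Dynamo-compiled run: expected to hit the cached XLA computation, so no new
# 'xla::add' counter should appear. This is the assertion that now fails.
dynamo_fn = torch._dynamo.optimize('torchxla_trace_once')(fn_simple)
dynamo_fn(a, b)
assert 'xla::add' not in met.counter_names()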

The full error log with the PyTorch dispatch trace is at https://gist.github.com/wonjoolee95/f47a50fdfc72ca585694ce1ce9290c14. My guess is that the cache is being missed somehow, but I need to look into it further.

Just noting some observations:

Updating with some dumps. On the functionalization branch, the IR dump looks like:

res_xla_dynamo: IR {
  %0 = f32[] xla::device_data(), location=test_simple_model@test_dynamo.py:40, xla_shape=f32[], device=CPU:0, ROOT=0
}

res_xla_dynamo_2: IR {
  %0 = f32[] prim::Constant(), location=fn_simple@test_dynamo.py:19, xla_shape=f32[], value=1
  %1 = f32[] xla::device_data(), location=test_simple_model@test_dynamo.py:35, xla_shape=f32[], device=CPU:0
  %2 = f32[] aten::sin(%1), location=fn_simple@test_dynamo.py:18, xla_shape=f32[]
  %3 = f32[] aten::mul(%2, %0), location=fn_simple@test_dynamo.py:19, xla_shape=f32[]
  %4 = f32[] xla::device_data(), location=test_simple_model@test_dynamo.py:34, xla_shape=f32[], device=CPU:0
  %5 = f32[] aten::cos(%4), location=fn_simple@test_dynamo.py:17, xla_shape=f32[]
  %6 = f32[] aten::add(%5, %3), location=fn_simple@test_dynamo.py:19, xla_shape=f32[], ROOT=0
}

Full HLO dump can be found at https://gist.github.com/wonjoolee95/57e12cb6947e56538e0cde23b261021f.


On the master branch, the IR dump looks like:

res_xla_dynamo: IR {
  %0 = f32[] xla::device_data(), location=test_simple_model@test_dynamo.py:40, xla_shape=f32[], device=CPU:0, ROOT=0
}

res_xla_dynamo_2: IR {
  %0 = f32[] xla::device_data(), location=test_simple_model@test_dynamo.py:45, xla_shape=f32[], device=CPU:0, ROOT=0
}

Full HLO dump can be found at https://gist.github.com/wonjoolee95/57e12cb6947e56538e0cde23b261021f.
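
For reference, here is a sketch of how IR and HLO dumps like the ones above can be pulled from a live XLA tensor. These are private torch_xla debug helpers, so the exact names are an assumption and may vary between versions.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.tensor(100.0, device=xm.xla_device())
res = t.cos() + t.sin()

print(torch_xla._XLAC._get_xla_tensors_text([res]))  # pending IR graph, like the dumps above
print(torch_xla._XLAC._get_xla_tensors_hlo([res]))   # corresponding HLO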

Okay, writing down some more findings. With the functionalization branch, it looks like the test_simple_model test case doesn't even enter the dynamo-xla integration anymore. The code flow never reaches the functions in https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/torchxla_integration.py.

By contrast, when running on the master branch with breakpoints set at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/torchxla_integration.py#L205 and https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py#L761, I can see the code stop there as expected. With the same breakpoints on the functionalization branch, the breakpoints are never hit.

This difference seems related to the warning message that gets thrown during the test_simple_model test case with the functionalization branch:

[2023-01-09 09:08:37,385] torch._dynamo.utils: [WARNING] Unsupported: meta converter nyi with fake tensor propagation.

This is thrown at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/utils.py#L718 and only appears on the functionalization branch. After this warning, the test case never reaches torchxla_integration.py.
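
One way to confirm that this warning corresponds to Dynamo giving up on the frame (so the XLA backend is never invoked) is to inspect Dynamo's internal counters after the run. This is only a sketch; the counter keys are an implementation detail and may differ across versions.

from torch._dynamo.utils import counters

# ... run the dynamo-compiled test function here ...

for name, ctr in counters.items():
    print(name, dict(ctr))
# An 'unimplemented' or 'graph_break' entry mentioning "meta converter nyi" here
# would line up with the warning above.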

@ezyang any idea if functionalization will work with fake tensor?

[2023-01-09 09:08:37,385] torch._dynamo.utils: [WARNING] Unsupported: meta converter nyi with fake tensor propagation.

I think what we can do is disable fake tensor (I think it is enabled by default) for the pytorch/xla dynamo bridge and see if that solves the issue. Let me look up how to do that.

@wonjoolee95 Maybe in your installed pytorch, modify https://github.com/pytorch/pytorch/blob/944519a46823eb4e95d05a6849c040050247eacc/torch/_functorch/config.py#L15 to false. I am a bit confused though, I thought fake tensor is only relevant for training; the test we have only covers inference.
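
Instead of editing the installed config.py, the same flag can be flipped at runtime before compiling. This is only a sketch; use_fake_tensor is the name at the link above, and whether it affects this inference-only path is exactly what is in question.

import torch._functorch.config as functorch_config

functorch_config.use_fake_tensor = False

# Some builds of this era also expose a dynamo-side switch, e.g.
# torch._dynamo.config.fake_tensor_propagation = False  # may not exist in your build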

Thanks for the info, Jack. I tried setting use_fake_tensor = False, however I'm still seeing the same warning message and the dynamo-xla integration still doesn't seem to get reached (breakpoints are not triggering).

Can you pdb at

[2023-01-09 09:08:37,385] torch._dynamo.utils: [WARNING] Unsupported: meta converter nyi with fake tensor propagation.

and tell me what the input tensor's type is. Is it an XLA tensor? A fake tensor? If it's a fake tensor, who fakeified it?
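
A few checks that could be run at the pdb prompt on the tensor being wrapped (called t below as a placeholder name) to answer these questions:

import torch
from torch._subclasses.fake_tensor import FakeTensor

print(type(t))                         # concrete Python type
print(t.device)                        # e.g. 'xla:0' for an XLA tensor
print(isinstance(t, FakeTensor))       # already fakeified?
print(torch._is_functional_tensor(t))  # wrapped in a functional tensor?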

Thank you @ezyang for the input.

With pdb, I can see that the tensor is an XLA tensor. The warning is raised from a call stack that goes through the self.wrap_tensor(value) call at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/variables/builder.py#L235.

The relevant stack trace looks like (the full stack trace is at https://gist.github.com/wonjoolee95/fc0aca31ef90cca77d782f81f7fa8911):

-> return self._wrap(value).clone(**self.options())
VariableBuilder.__call__ (at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/variables/builder.py#L169)
-> return self.wrap_tensor(value)
VariableBuilder._wrap (at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/variables/builder.py#L235)
-> tensor_variable = wrap_fx_proxy(
VariableBuilder.wrap_tensor (at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/variables/builder.py#L616)
-> return wrap_fx_proxy_cls(
wrap_fx_proxy (at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/variables/builder.py#L731)
-> example_value = wrap_to_fake_tensor_and_record(
wrap_fx_proxy_cls (at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/variables/builder.py#L791)
-> fake_e = wrap_fake_exception(
wrap_to_fake_tensor_and_record (at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/variables/builder.py#L942)
-> raise unimplemented(msg) from e
wrap_fake_exception (at https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/utils.py#L720)

And FWIW, the locals() with pdb at this point are:

{'tx': <torch._dynamo.symbolic_convert.InstructionTranslator object at 0x7effd10857d0>, 
'proxy': Proxy(x), 'example_value': tensor(100., device='xla:0'), 
'options': {'guards': {Guard(name='x', source=<GuardSource.LOCAL: 0>, 
create_fn=<function GuardBuilder.TENSOR_MATCH at 0x7effd84c4e60>, 
is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, 
guarded_class_weakref=None)}, 'should_specialize': False, 'ignore_subclass': False, 
'source': LocalSource(local_name='x')}}

OK so I kind of know what's going on here but I'm not sure exactly what to do. Let me first tell you what's going on.

When Dynamo traces code, traditionally we are given a bunch of CPU/CUDA tensors, and we need to be able to dry-run the operators to find out what the intermediate shapes are, so that we can, e.g., keep going even if the user writes if intermediate.size(0) == 10. To do this, we turn the tensor into a fake tensor and then run the ops on the fake tensor. Internally, fakeification proceeds by making a meta tensor and then turning that into the fake tensor subclass.
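
As a small illustration of the fakeification described here: a real tensor is converted (via a meta tensor) into a FakeTensor, and ops on it only propagate shapes, dtypes, and devices without doing real compute. A sketch using FakeTensorMode:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode()
real = torch.randn(4, 4)
fake = mode.from_tensor(real)   # fakeify: meta tensor wrapped as a FakeTensor
with mode:
    out = fake @ fake           # "dry run": only shape/dtype inference happens
print(out.shape, out.device)    # torch.Size([4, 4]) cpu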

Now, this logic is all heavily biased towards the CPU/CUDA case. But with XLA, obviously you will have an XLA tensor. Ordinarily this would work, but I think XLA tensors report torch._is_functional_tensor(t) as true, and so you hit this case in meta conversion:

        if (    
            type(t) is torch.Tensor
            or type(t) is torch.nn.Parameter
            or (ignore_subclass and isinstance(t, torch.Tensor))
            or isinstance(t, FakeTensor)
        ):      
            if any(
                [
                    t.is_sparse_csr,
                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                    t.is_quantized,
                    t.is_nested,
                    t._is_view() and t._base is not None and t._base.is_sparse,
                    torch._is_functional_tensor(t),
                    # these are supported in meta conversion but the fallbacks
                    # don't work
                    t.is_neg(),
                    t.is_conj(),
                    t.device.type in ("lazy", "meta"),
                    # We need a way to test if a tensor is batched but there
                    # is no official APi to do it
                    # torch._C._is_batched(t),
                ]
            ):
                # TODO: sparse should support meta
                # NB technically to('meta') does work but our logging
                # instrumentation will see the meta conversions and the
                # tests all break so we just exclude this.  In any case
                # the to conversion isn't really right anyhow.
                self.miss += 1
                return NotImplemented

You could pdb this code to verify that my theory is true.

So, how to fix? As a start I'd try adding a carveout for XLA here. But there might be other stuff that breaks, not sure...
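
For illustration only, one possible shape of such a carveout in the any([...]) list quoted above would be to skip the functional-tensor bailout when the tensor lives on an XLA device. This is a sketch of the idea, not a tested fix; the real solution may instead need to unwrap the functional wrapper before meta conversion.

                    # hypothetical carveout: don't bail out on functional tensors that are XLA tensors
                    torch._is_functional_tensor(t) and t.device.type != "xla",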