openxla / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page: http://iree.dev/

Trying to run gpt2 results in `ValueError: unimplemented array format conversion from format: ?`

navdeepkk-polymagelabs opened this issue.

What happened?

I was trying to lower the GPT-2 model from Hugging Face and ran into the following error while importing it into MLIR.

Steps to reproduce your issue

Consider the following Python script, which tries to convert GPT-2 to IREE IR.

```python
import torch
from transformers import GPT2Tokenizer, GPT2Model, GPT2Config
from shark_turbine import aot


def initialize_gpt2():
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    text = "Replace me by any text you'd like."
    encoded_input = tokenizer(text, return_tensors="pt")
    return encoded_input


config = GPT2Config(return_dict=False)
model = GPT2Model(config)
model.eval()

# Have to create a list of tensors to supply to `aot.export`
inputs = list(initialize_gpt2().values())

with torch.no_grad():

    class CustomModule(aot.CompiledModule):
        compute = aot.jittable(model.forward)

        def main(
            self,
            input_ids=aot.abstractify(inputs[0]),
            attention_mask=aot.abstractify(inputs[1]),
        ):
            return self.compute(
                input_ids=input_ids, attention_mask=attention_mask
            )

    export_output = aot.export(CustomModule, *inputs)
```

Executing this results in the following error:

```
Traceback (most recent call last):
  File "/home1/navdeep/work/projects/polyblocks-compiler/external/iree/examples/pytorch/gpt2_iree.py", line 34, in <module>
    export_output = aot.export(CustomModule, *inputs)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/aot/exporter.py", line 199, in export
    cm = Exported(context=context, import_to="import")
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/aot/compiled_module.py", line 538, in __new__
    do_export(proc_def)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/aot/compiled_module.py", line 535, in do_export
    trace.trace_py_func(invoke_with_self)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/aot/support/procedural/tracer.py", line 121, in trace_py_func
    return_py_value = _unproxy(py_f(*self.proxy_posargs, **self.proxy_kwargs))
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/aot/compiled_module.py", line 516, in invoke_with_self
    return proc_def.callable(self, *args, **kwargs)
  File "/home1/navdeep/work/projects/polyblocks-compiler/external/iree/examples/pytorch/gpt2_iree.py", line 30, in main
    return self.compute(
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/aot/support/procedural/base.py", line 137, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/aot/support/procedural/tracer.py", line 137, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/aot/builtins/jittable.py", line 239, in resolve_call
    fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/dynamo/importer.py", line 266, in import_stateless_graph
    node_importer.import_nodes(g.nodes)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/dynamo/importer.py", line 501, in import_nodes
    self._import_torch_op_overload(loc, node, target)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/dynamo/importer.py", line 669, in _import_torch_op_overload
    self._import_argument(loc, node.args[i], parameter.type)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/dynamo/importer.py", line 718, in _import_argument
    self._v[(arg, 0)] = self._import_literal(obj)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/dynamo/importer.py", line 747, in _import_literal
    return converter(py_value, self, self._cc)
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/dynamo/importer.py", line 948, in <lambda>
    lambda arg, gni, cc: _make_vtensor_literal_op(
  File "/home1/navdeep/.venv/base/lib/python3.10/site-packages/shark_turbine/dynamo/importer.py", line 907, in _make_vtensor_literal_op
    elements_attr = DenseElementsAttr.get(bytes, signless=False)
ValueError: unimplemented array format conversion from format: ?
```

Is this a type-conversion issue in shark_turbine? Any help with this would be appreciated.

Thanks!

What component(s) does this issue relate to?

No response

Version information

I am using the following python packages:

iree-compiler            20231113.707
iree-runtime             20231113.707
shark-turbine            0.9.2
torch                    2.1.1

Additional context

No response

This looks like an unimplemented case for a tensor element type. See the format codes here: https://docs.python.org/3/library/struct.html

'?' is the format code for bool.

The importer is trying to emit some kind of bool constant and failing.
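To make the format code concrete, here is a small stdlib-only sketch (an illustration, not code from shark_turbine or the MLIR Python bindings) showing that a buffer of booleans reports the buffer-protocol format `?`, the same code that appears in the `ValueError` above. It assumes only what the error message suggests: that the object handed to `DenseElementsAttr.get` exposes such a buffer.

```python
import struct

# '?' is the struct / buffer-protocol format character for C _Bool.
# Decoding three bytes as three bools:
print(struct.unpack("3?", bytes([1, 0, 1])))  # (True, False, True)

# Reinterpret raw bytes as a 1-D buffer of bools. Its format is '?',
# which is the format the DenseElementsAttr.get(...) call in the
# traceback reports as unimplemented.
bool_view = memoryview(bytes([1, 0, 1])).cast("?")
print(bool_view.format)    # ?
print(bool_view.tolist())  # [True, False, True]
```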

Thanks. Yes, I could see from the stack trace that it is trying to create a DenseElementsAttr from a bool tensor. Should shark-turbine handle this case as well? I am using the model as-is from Hugging Face.
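The bytes behind a bool buffer are already valid 8-bit integers (0 or 1), so one conceivable workaround is to reinterpret a `?` buffer as unsigned bytes (`B`) before it reaches a converter that only understands numeric formats. The sketch below shows only that reinterpretation, using the standard library. The helper name `as_uint8_if_bool` is hypothetical, it is not shark_turbine's actual fix, and it does not settle whether the resulting MLIR attribute should be typed `i1` or `i8`.

```python
def as_uint8_if_bool(view: memoryview) -> memoryview:
    """Hypothetical helper: reinterpret a C-contiguous bool ('?') buffer as uint8 ('B').

    Non-bool buffers are returned unchanged. No data is copied; only the
    format (and, for N-D views, the shape) of the view changes.
    """
    if view.format != "?":
        return view
    flat = view.cast("B")  # N-D '?' -> 1-D 'B' (requires C-contiguity)
    if view.ndim > 1:
        return flat.cast("B", view.shape)  # restore the original shape
    return flat


# Made-up 2x2 bool buffer standing in for a small bool constant tensor.
mask = memoryview(bytes([1, 0, 1, 1])).cast("?", (2, 2))
as_bytes = as_uint8_if_bool(mask)
print(mask.format, as_bytes.format)  # ? B
print(as_bytes.tolist())             # [[1, 0], [1, 1]]
```

A real fix would more likely live in the importer's literal-conversion path (`_make_vtensor_literal_op` in the traceback), since that is also where the element type of the emitted constant is decided.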