hidet-org / hidet

An open-source, efficient deep learning framework/compiler, written in Python.

Home Page: https://hidet.org

[NotImplementedError] operator.lshift

HanGuo97 opened this issue

Describe the Problem
BackendCompilerFailed: hidet_backend raised NotImplementedError: The following modules/functions are not supported by hidet yet:
operator.lshift

Thanks @HanGuo97 for reporting this.

Hi @Aalanli, could you help add this operator when you have a chance?

Fixed in #371

Thanks for the quick fix!

Does that mean I could simply build from source, and this would work via the torch.compile interface?

I am not sure about the particulars of your model, but this script worked for me (when built from source):

import hidet  # importing hidet makes the 'hidet' backend available to torch.compile
import torch

def test(a):
    return a << 3  # operator.lshift, the op added in #371

t = torch.compile(test, backend='hidet')
# bitwise shifts are only defined for integer tensors, hence the cast to int64
t(torch.randn(3, 5, device='cuda').to(torch.int64))

Hi @Aalanli,
Thank you so much for your help! The patch you provided worked perfectly for me.

I wanted to follow up with a more detailed example. This issue surfaced several missing operators for me. Some of them, such as pow, bitwise_and, and torch.Tensor.max, I was able to fix myself by following the pattern of your patch (sketched below).
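
For reference, the registration pattern I followed looks roughly like this. The import path and op names here are my guesses, not verified against the codebase; the real API lives in hidet's torch frontend:

import operator
from hidet.graph import ops
# NOTE: the import path below is an assumption; in the actual repo the
# decorator lives next to hidet/graph/frontend/torch/register_functions.py
from hidet.graph.frontend.torch.registry import register_function

@register_function(operator.lshift)
def operator_lshift(x, y):
    # hypothetical hidet op name; the one added in #371 may differ
    return ops.bitwise_left_shift(x, y)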

However, there is one error that I’m having difficulty fixing without additional knowledge of the codebase. Here’s the error message:

ValueError: Unknown data type: uint8x4, candidates...

I’ve attached a simplified script that reproduces this error, and possibly a few others. I would appreciate any help you can offer in resolving this issue. Thank you in advance for your time!
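
Before the full script, here is a much smaller sketch of the pattern I suspect is the trigger: the uint8 shift-and-mask in unpack_uint8_into_bool. I have not isolated whether this alone reproduces the uint8x4 error, so treat it as a guess:

import torch
import hidet  # makes the 'hidet' backend available

# Hypothetical minimal repro, distilled from unpack_uint8_into_bool below
@torch.compile(fullgraph=True, backend="hidet")
def shift_and_mask(packed):
    bits = torch.tensor(1, dtype=torch.uint8, device=packed.device)
    bits = bits << torch.arange(8, dtype=torch.uint8, device=packed.device)
    return (packed.unsqueeze(-1) & bits) > 0

shift_and_mask(torch.randint(256, size=(64,), dtype=torch.uint8, device="cuda"))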

Code
import math
import torch
import jaxtyping
from typing import Tuple

DEFAULT_CONTAINER_NUM_BITS = 8
FloatTensorType = jaxtyping.Float[torch.Tensor, "..."]
UInt8TensorType = jaxtyping.UInt8[torch.Tensor, "..."]
BinaryTensorType = jaxtyping.Bool[torch.Tensor, "..."]
PackedBinaryTensorType = jaxtyping.UInt8[torch.Tensor, "..."]


def from_binary(tensor: BinaryTensorType, num_bits: int) -> UInt8TensorType:
    if tensor.dtype != torch.bool:
        raise TypeError
    if tensor.shape[-1] != num_bits:
        raise ValueError
    if num_bits > 8:
        raise NotImplementedError
    # Powers of two, MSB first: [2^(num_bits-1), ..., 2, 1]
    mask = torch.tensor([2], dtype=torch.float32, device=tensor.device) ** torch.arange(
        num_bits - 1, -1, -1,
        dtype=torch.float32,
        device=tensor.device)
    mask = mask.to(dtype=torch.uint8)
    tensor = tensor.to(dtype=torch.uint8)
    output = torch.sum(mask * tensor, dim=-1)
    output = output.to(dtype=torch.uint8)
    return output


def unpack_uint8_into_bool(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
) -> BinaryTensorType:
    if packed_tensor.ndim != 1:
        raise ValueError
    if packed_tensor.dtype != torch.uint8:
        raise TypeError
    # Some constants
    packed_dtype = torch.uint8
    packed_num_bits = torch.iinfo(packed_dtype).bits

    # [1, packed_num_bits]
    bits = torch.tensor(
        1,
        dtype=packed_dtype,
        device=packed_tensor.device)
    # [1, 2, 4, ..., 128]: LSB first
    bits = bits << torch.arange(
        packed_num_bits,
        dtype=packed_dtype,
        device=packed_tensor.device)
    bits = torch.unsqueeze(
        bits,
        dim=0)
    unpacked_tensor = torch.unsqueeze(
        packed_tensor,
        dim=-1)
    unpacked_tensor = unpacked_tensor & bits
    unpacked_tensor = unpacked_tensor > 0
    unpacked_tensor = unpacked_tensor.to(dtype=torch.bool)
    unpacked_tensor = unpacked_tensor.view(-1)
    if padding_length > 0:
        unpacked_tensor = unpacked_tensor[:-padding_length]
    return unpacked_tensor


@torch.compile(fullgraph=True, backend="hidet")
def unpack_integer_tensors(
    packed_tensor: PackedBinaryTensorType,
    padding_length: int,
    num_bits: int,
    shape: Tuple[int, ...],
) -> UInt8TensorType:
    # integer division: packed_size is a count of uint8 containers
    packed_size = (
        (math.prod(shape) * num_bits + padding_length) //
        DEFAULT_CONTAINER_NUM_BITS)
    if packed_tensor.shape != (packed_size,):
        raise ValueError

    # [tensor.numel() x num_bits / 8]
    packed_tensor = packed_tensor.contiguous()
    # [tensor.numel() x num_bits]
    binary_tensor = unpack_uint8_into_bool(
        packed_tensor=packed_tensor,
        padding_length=padding_length)
    # [*tensor.shape, num_bits]
    binary_tensor = binary_tensor.view(
        *shape, num_bits)
    return from_binary(
        tensor=binary_tensor,
        num_bits=num_bits)


num_bits = 8
shape = torch.Size([1024, 256, 1])
unpack_integer_tensors(
    torch.randint(
        2 ** 8,
        size=(shape.numel(),),
        dtype=torch.uint8,
        device="cuda"),
    padding_length=0,
    num_bits=num_bits,
    shape=shape,
)
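
As an aside, here is a quick eager-mode round trip over the two helpers, independent of the compiler issue. Note that from_binary packs MSB-first while unpack_uint8_into_bool unpacks LSB-first, so the round trip flips each group of bits:

b = torch.randint(0, 2, (16, 8), device="cuda").to(torch.bool)
packed = from_binary(b, num_bits=8)                      # [16], MSB-first
bits = unpack_uint8_into_bool(packed, padding_length=0)  # LSB-first
assert torch.equal(bits.view(16, 8), b.flip(-1))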

No problem!
For me, PR #372 works on the code provided.
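
If it helps as a sanity check (assuming a source build that includes the PR), you can compare the compiled helper against eager:

compiled = torch.compile(unpack_uint8_into_bool, backend="hidet")
x = torch.randint(256, size=(64,), dtype=torch.uint8, device="cuda")
assert torch.equal(compiled(x, 0), unpack_uint8_into_bool(x, 0))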

Amazing, thanks a ton! Will give this a try soon.