numba / numba

NumPy aware dynamic Python compiler using LLVM

Home Page: https://numba.pydata.org/

Support for Advanced Numpy dtypes

OyiboRivers opened this issue

Feature request

Hey Numba team,

Numba currently has limited support for advanced NumPy dtypes.
It handles basic dtypes such as integers and floats, but it fails to compile functions when an advanced dtype, such as the subarray dtype np.dtype(('<i8', (3,))), is passed as a function argument.

import numpy as np
from numba import njit

def make_zeros_np(shape, dtype):
    return np.zeros(shape, dtype)

@njit
def make_zeros_nb(shape, dtype):
    return np.zeros(shape, dtype)

# *********************************
# basic dtype
# *********************************
shape = (2, 3)
dtype = np.dtype('<i8')

print(make_zeros_np(shape, dtype))
# [[0 0 0]
#  [0 0 0]]

print(make_zeros_nb(shape, dtype))
# [[0 0 0]
#  [0 0 0]]

# *********************************
# advanced dtype
# *********************************
shape = (2,)
dtype = np.dtype(('<i8', (3,)))

print(make_zeros_np(shape, dtype))
# [[0 0 0]
#  [0 0 0]]

print(make_zeros_nb(shape, dtype))
# TypingError: No implementation of function Function(<built-in function zeros>) found for signature:
# zeros(UniTuple(int64 x 1), dtype(nestedarray(int64, (3,))))

This functionality would be useful for implementing functions like numpy.fromiter(iterable, dtype).
With a subarray dtype, an iterable of 1-D arrays (for example, one produced by a generator function) could be read directly into a 2-D array:
np.fromiter(iterable, dtype=np.dtype(('<i8', (3,))))

Have a great day!

Noting from discussion in the triage meeting that NumPy 2.0 also has some dtype extension mechanisms that might help us here; this is worth revisiting once NumPy 2.0 support is added to Numba.