tlc-pack / tvm-tensorir

[bug] BufferSlice does not support __sub__

yzh119 opened this issue

When I tried to write a loop whose lower and upper bounds were both BufferSlices (`indptr[i]` and `indptr[i + 1]`):

```python
import tvm
from tvm import tir
from tvm.script import ty


@tvm.script.tir
def spmm(w_indptr: ty.handle, w_indices: ty.handle, w: ty.handle, x: ty.handle, y: ty.handle) -> None:
    m = tir.var('int32')
    k = tir.var('int32')
    nnz = tir.var('int32')
    n = tir.var('int32')
    indptr = tir.match_buffer(w_indptr, [m], 'int32')
    indices = tir.match_buffer(w_indices, [nnz], 'int32')
    W = tir.match_buffer(w, [nnz], 'float32')
    X = tir.match_buffer(x, [k, n], 'float32')
    Y = tir.match_buffer(y, [m, n], 'float32')
    for i, k in tir.grid(m, n):
        # Both loop bounds are BufferSlices; this is the line the error points at.
        for j in tir.serial(indptr[i], indptr[i + 1]):
            with tir.block([m, n, tir.reduce_axis(0, indptr[i + 1] - indptr[i])], 'spmm') as [ii, kk, jj]:
                tir.bind(ii, i)
                tir.bind(jj, j + indptr[i])
                tir.bind(kk, k)
                with tir.init():
                    Y[ii, kk] = tir.float32(0.)
                Y[ii, kk] = Y[ii, kk] + W[jj] * X[indices[jj], kk]
```

I got the following error:

```
error: unsupported operand type(s) for -: 'BufferSlice' and 'BufferSlice'
 --> test_tir_schedule_normalize.py:49:18
    |  
 49 |          for j in tir.serial(indptr[i], indptr[i + 1]):
    |                   ^^^^^^^^^^                           
Traceback (most recent call last):
  File "/home/zihao/tvm-tensorir/python/tvm/script/utils.py", line 132, in call_with_error_reporting
    return func(*args, **kwargs)
  File "/home/zihao/tvm-tensorir/python/tvm/script/scope_handler.py", line 501, in serial
    return self.create_loop(begin, end, ForKind.SERIAL, annotations=annotations, span=span)
  File "/home/zihao/tvm-tensorir/python/tvm/script/scope_handler.py", line 471, in create_loop
    extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
TypeError: unsupported operand type(s) for -: 'BufferSlice' and 'BufferSlice'
```
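The last frame of the traceback hints at why the two loop forms behave differently: `create_loop` computes `extent = end if begin == 0 else ...simplify(end - begin)`, so a lower bound of 0 skips the subtraction entirely, while any other lower bound applies Python's `-` to the raw arguments. The sketch below reproduces just that logic without TVM. `FakeBufferSlice` and `create_loop_extent` are hypothetical stand-ins, not TVM classes, and the only assumption carried over is that `BufferSlice` is a plain Python class with no `__sub__`, as the error message shows.

```python
class FakeBufferSlice:
    """Hypothetical stand-in for tvm.script's BufferSlice: a plain Python
    class that records a buffer name and an index, with no arithmetic operators."""

    def __init__(self, buffer_name, index):
        self.buffer_name = buffer_name
        self.index = index


def create_loop_extent(begin, end):
    """Mimics the extent line quoted in the traceback:
    extent = end if begin == 0 else simplify(end - begin)
    The simplify step is omitted; only the subtraction matters here."""
    return end if begin == 0 else end - begin


# Lower bound 0: `end` is returned untouched, so no `-` is ever evaluated.
extent = create_loop_extent(0, FakeBufferSlice("indptr", "i + 1"))
print(type(extent).__name__)  # FakeBufferSlice

# Lower bound is a slice: a plain object never compares equal to 0, so
# `end - begin` runs on two objects without __sub__/__rsub__ and raises.
try:
    create_loop_extent(FakeBufferSlice("indptr", "i"), FakeBufferSlice("indptr", "i + 1"))
except TypeError as err:
    print(err)  # unsupported operand type(s) for -: 'FakeBufferSlice' and 'FakeBufferSlice'
```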

However, if I write the loop another way, starting at 0 with `indptr[i + 1] - indptr[i]` as the extent:

```python
import tvm
from tvm import tir
from tvm.script import ty


@tvm.script.tir
def spmm(w_indptr: ty.handle, w_indices: ty.handle, w: ty.handle, x: ty.handle, y: ty.handle) -> None:
    m = tir.var('int32')
    k = tir.var('int32')
    nnz = tir.var('int32')
    n = tir.var('int32')
    indptr = tir.match_buffer(w_indptr, [m], 'int32')
    indices = tir.match_buffer(w_indices, [nnz], 'int32')
    W = tir.match_buffer(w, [nnz], 'float32')
    X = tir.match_buffer(x, [k, n], 'float32')
    Y = tir.match_buffer(y, [m, n], 'float32')
    for i, k in tir.grid(m, n):
        # Lower bound is 0, so create_loop takes the extent as-is.
        for j in tir.serial(0, indptr[i + 1] - indptr[i]):
            with tir.block([m, n, tir.reduce_axis(0, indptr[i + 1] - indptr[i])], 'spmm') as [ii, kk, jj]:
                tir.bind(ii, i)
                tir.bind(jj, indices[j + indptr[i]])
                tir.bind(kk, k)
                with tir.init():
                    Y[ii, kk] = tir.float32(0.)
                Y[ii, kk] = Y[ii, kk] + W[jj] * X[jj, kk]
```

The program works.

Do we support explicit conversion from BufferSlice to PrimExpr? Since a BufferSlice is not necessarily a scalar, such a conversion cannot always succeed.

@Hzfengsy @junrushao1994 @spectrometerHBH
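To make the scalar caveat concrete: an explicit conversion could succeed when every axis of the slice is indexed by a single point and fail loudly when any axis covers a range. The TVM-free sketch below models a slice as a buffer name plus a tuple of indices, where a Python `slice` object marks a ranged axis. `SliceSketch`, `is_scalar`, and `to_scalar_load` are hypothetical names, and the returned string only stands in for a scalar load expression; none of this is TVM's actual API.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# A symbolic index such as "i + 1" is modelled as a string; a Python slice
# object marks an axis that covers a range rather than a single point.
Index = Union[int, str, slice]


@dataclass(frozen=True)
class SliceSketch:
    """Hypothetical model of a buffer slice: a buffer name plus one index per axis."""

    buffer_name: str
    indices: Tuple[Index, ...]

    def is_scalar(self) -> bool:
        # The slice addresses exactly one element only if no axis is a range.
        return not any(isinstance(idx, slice) for idx in self.indices)

    def to_scalar_load(self) -> str:
        """Explicit conversion: returns text standing in for a scalar load
        expression, or raises if the slice covers more than one element."""
        if not self.is_scalar():
            raise ValueError(
                f"slice of {self.buffer_name} has a ranged axis; "
                "it cannot be converted to a scalar PrimExpr"
            )
        return f"{self.buffer_name}[{', '.join(map(str, self.indices))}]"


print(SliceSketch("indptr", ("i + 1",)).to_scalar_load())  # indptr[i + 1]

try:
    SliceSketch("Y", ("ii", slice(0, 4))).to_scalar_load()
except ValueError as err:
    print(err)  # slice of Y has a ranged axis; it cannot be converted to a scalar PrimExpr
```

Raising instead of silently returning something keeps the "not always possible" case visible at the point where the conversion is requested.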

That's a good question. BufferSlice is a plain Python class and has no type conversion of its own. Conversion from BufferSlice to PrimExpr happens only in `asobject`, which converts the Python object into a TVM object.

I have been thinking about this question for some time but have not found a good answer yet. It would be great if you could share your ideas.
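One possible direction, sketched under assumptions: give the Python-side slice class arithmetic dunder methods that first run the same conversion that `asobject` performs, so that `end - begin` inside `create_loop` would receive converted operands. The sketch is TVM-free; `Expr`, `SliceWithOps`, its `asobject` method, and `_as_expr` are hypothetical stand-ins rather than TVM code, and the conversion is permitted only for scalar slices, matching the caveat raised in the issue.

```python
class Expr:
    """Hypothetical stand-in for a scalar PrimExpr: a printable expression string."""

    def __init__(self, text):
        self.text = text

    def __sub__(self, other):
        return Expr(f"({self.text} - {_as_expr(other).text})")

    def __rsub__(self, other):
        return Expr(f"({_as_expr(other).text} - {self.text})")

    def __repr__(self):
        return self.text


def _as_expr(value):
    """Convert an int, an Expr, or a slice (via its asobject) into an Expr."""
    if isinstance(value, Expr):
        return value
    if isinstance(value, SliceWithOps):
        return value.asobject()
    if isinstance(value, int):
        return Expr(str(value))
    raise TypeError(f"cannot convert {type(value).__name__} to Expr")


class SliceWithOps:
    """Hypothetical slice class that converts itself before any arithmetic."""

    def __init__(self, buffer_name, index, is_scalar=True):
        self.buffer_name = buffer_name
        self.index = index
        self.is_scalar = is_scalar

    def asobject(self):
        # Only a single-element slice can become a scalar expression.
        if not self.is_scalar:
            raise ValueError(f"{self.buffer_name}[{self.index}] is not a scalar")
        return Expr(f"{self.buffer_name}[{self.index}]")

    def __sub__(self, other):
        # slice - x: convert the left operand, then let Expr handle the rest.
        return self.asobject() - other

    def __rsub__(self, other):
        # x - slice: Python falls back here when x does not know about slices.
        return _as_expr(other) - self.asobject()


begin = SliceWithOps("indptr", "i")
end = SliceWithOps("indptr", "i + 1")
print(end - begin)  # (indptr[i + 1] - indptr[i])
print(5 - begin)  # (5 - indptr[i])
```

With operators defined this way, the extent computation in `create_loop` would no longer hit a `TypeError` for scalar slices, while a non-scalar slice still fails with an explicit error instead of producing a meaningless expression.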

Can we consolidate this issue into #471?