Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.

Home Page: https://cg.cs.tsinghua.edu.cn/jittor/

Is there an efficient implementation of a block_diag function?

liylo opened this issue

Here is a naive implementation. Is there a better one?

```python
import jittor as jt


def block_diag(*tensors):
    # Calculate the total shape of the block diagonal matrix.
    # 0-D tensors are 1x1 blocks and 1-D tensors are 1xN row blocks
    # (the same convention as torch.block_diag).
    total_rows = sum(tensor.shape[0] if tensor.ndim == 2 else 1 for tensor in tensors)
    total_cols = sum(tensor.shape[-1] if tensor.ndim > 0 else 1 for tensor in tensors)

    # Initialize the block diagonal matrix with zeros
    block_matrix = jt.zeros((total_rows, total_cols), dtype=tensors[0].dtype)

    current_row = 0
    current_col = 0

    # Place each tensor in the block diagonal matrix, one element at a time
    for tensor in tensors:
        rows = tensor.shape[0] if tensor.ndim == 2 else 1
        cols = tensor.shape[-1] if tensor.ndim > 0 else 1

        if tensor.ndim == 0:
            block_matrix[current_row, current_col] = tensor
        elif tensor.ndim == 1:
            # A 1-D tensor fills a single row across `cols` columns
            for j in range(cols):
                block_matrix[current_row, current_col + j] = tensor[j]
        else:
            for i in range(rows):
                for j in range(cols):
                    block_matrix[current_row + i, current_col + j] = tensor[i, j]

        current_row += rows
        current_col += cols

    return block_matrix
```