Is there an efficient implementation of a block_diag function?
liylo opened this issue
Liu Yilong commented
Here is a naive implementation. Do you have a better one?
```python
import jittor as jt

def block_diag(*tensors):
    # 0-D tensors are treated as 1x1 blocks, 1-D tensors of length n as 1xn rows
    def block_shape(tensor):
        rows = tensor.shape[0] if tensor.ndim == 2 else 1
        cols = tensor.shape[-1] if tensor.ndim > 0 else 1
        return rows, cols

    # Calculate the total shape of the block diagonal matrix
    total_rows = sum(block_shape(tensor)[0] for tensor in tensors)
    total_cols = sum(block_shape(tensor)[1] for tensor in tensors)
    # Initialize the block diagonal matrix with zeros
    block_matrix = jt.zeros((total_rows, total_cols), dtype=tensors[0].dtype)
    current_row = 0
    current_col = 0
    # Place each tensor in the block diagonal matrix, one element at a time
    for tensor in tensors:
        rows, cols = block_shape(tensor)
        if tensor.ndim == 0:
            block_matrix[current_row, current_col] = tensor
        elif tensor.ndim == 1:
            for i in range(cols):
                block_matrix[current_row, current_col + i] = tensor[i]
        else:
            for i in range(rows):
                for j in range(cols):
                    block_matrix[current_row + i, current_col + j] = tensor[i, j]
        current_row += rows
        current_col += cols
    return block_matrix
```
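Most of the cost above comes from the per-element Python loops. A likely faster variant copies each block with a single slice assignment. The sketch below is written against NumPy so it runs without Jittor; `block_diag_sliced` and its `zeros` parameter are hypothetical names. Passing Jittor Vars together with `zeros=jt.zeros` should work only if Jittor's `Var` supports NumPy-style `reshape` and slice assignment in the same way, which is an assumption here.

```python
import numpy as np

def block_diag_sliced(*tensors, zeros=np.zeros):
    """Build a block-diagonal matrix with one slice copy per block.

    `zeros` is any factory with a NumPy-like (shape, dtype=...) signature.
    Swapping in jt.zeros for Jittor inputs is an untested assumption.
    """
    # Promote 0-D to (1, 1) and 1-D of length n to (1, n); 2-D blocks are kept
    blocks = [t.reshape((1, -1)) if t.ndim < 2 else t for t in tensors]

    total_rows = sum(b.shape[0] for b in blocks)
    total_cols = sum(b.shape[1] for b in blocks)
    out = zeros((total_rows, total_cols), dtype=blocks[0].dtype)

    row, col = 0, 0
    for b in blocks:
        h, w = b.shape[0], b.shape[1]
        # One vectorized copy per block instead of one assignment per element
        out[row:row + h, col:col + w] = b
        row += h
        col += w
    return out

a = np.array([[1, 2], [3, 4]])
print(block_diag_sliced(a, np.array(5), np.array([6, 7])))
```

The loop now runs once per block rather than once per element. The output dtype follows the first block, so mixing integer and float inputs will truncate the floats unless the first block is already floating point.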