Add rewrites to lift/flatten `Subtensor`s applied to `IncSubtensor`s
brandonwillard opened this issue
Brandon T. Willard commented
The following illustrates the missing rewrite/optimization:
import aesara
import aesara.tensor as at
# The proposed rewrite should also work when `A` is an array
A = 1
B = at.set_subtensor(at.zeros((11, 2))[:1], A)[:5]
aesara.dprint(B)
# Subtensor{:int64:} [id A]
# |IncSubtensor{Set;:int64:} [id B]
# | |Alloc [id C]
# | | |TensorConstant{0.0} [id D]
# | | |TensorConstant{11} [id E]
# | | |TensorConstant{2} [id F]
# | |TensorConstant{1} [id G]
# | |ScalarConstant{1} [id H]
# |ScalarConstant{5} [id I]
f_B = aesara.function([], B, mode="FAST_RUN")
# As we can see, the only rewrite applied is the in-place one (`Set` became
# `InplaceSet`); the outer `Subtensor` has not been lifted, so we're still
# allocating an unnecessarily large array (i.e. with shape (11, 2) instead of
# (5, 2)):
aesara.dprint(f_B)
# Subtensor{:int64:} [id A] 2
# |IncSubtensor{InplaceSet;:int64:} [id B] 1
# | |Alloc [id C] 0
# | | |TensorConstant{0.0} [id D]
# | | |TensorConstant{11} [id E]
# | | |TensorConstant{2} [id F]
# | |TensorConstant{1} [id G]
# | |ScalarConstant{1} [id H]
# |ScalarConstant{5} [id I]
f_B()
# array([[1., 1.],
#        [0., 0.],
#        [0., 0.],
#        [0., 0.],
#        [0., 0.]])
Aside from being a generally good optimization to have, it would also simplify, or even obviate, the logic in `save_mem_new_scan`, and perhaps logic in other places as well.
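
To make the intended effect concrete, here is a minimal sketch of the index arithmetic for the simplest case only: a single leading dimension, with both the set slice and the outer slice being prefix slices (`[:k]`) that have non-negative stops. It uses plain Python lists rather than Aesara, and the helper names `set_then_slice` and `slice_then_set` are hypothetical, not part of any library. Under those assumptions, a rewritten graph for the example above would presumably look something like `at.set_subtensor(at.zeros((5, 2))[:1], A)`; a real rewrite would also need to handle steps, negative indices, and multiple dimensions.

```python
def set_then_slice(dim_len, n_cols, set_stop, outer_stop, value):
    """Unoptimized order: allocate every row, set the prefix, then slice."""
    buf = [[0.0] * n_cols for _ in range(dim_len)]
    for i in range(min(set_stop, dim_len)):
        buf[i] = [float(value)] * n_cols
    return buf[:outer_stop]


def slice_then_set(dim_len, n_cols, set_stop, outer_stop, value):
    """Lifted order: allocate only the rows that survive the outer slice."""
    # Assumes non-negative stops, so the outer slice keeps the first
    # min(outer_stop, dim_len) rows.
    new_len = min(outer_stop, dim_len)
    buf = [[0.0] * n_cols for _ in range(new_len)]
    # The set slice is clipped to the smaller buffer.
    for i in range(min(set_stop, new_len)):
        buf[i] = [float(value)] * n_cols
    return buf


# Mirrors the example above: zeros((11, 2)), set [:1] to 1, then take [:5].
expected = set_then_slice(11, 2, 1, 5, 1)
lifted = slice_then_set(11, 2, 1, 5, 1)
assert lifted == expected
```

The second function never allocates more than `min(outer_stop, dim_len)` rows, which is the saving the proposed rewrite is after.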