aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.

Home Page: https://aesara.readthedocs.io

Add rewrites to lift/flatten `Subtensor`s applied to `IncSubtensor`s

brandonwillard opened this issue

The following illustrates the missing rewrite/optimization:

import aesara
import aesara.tensor as at


# The proposed rewrite should also work when `A` is an array
A = 1
B = at.set_subtensor(at.zeros((11, 2))[:1], A)[:5]

aesara.dprint(B)
# Subtensor{:int64:} [id A]
#  |IncSubtensor{Set;:int64:} [id B]
#  | |Alloc [id C]
#  | | |TensorConstant{0.0} [id D]
#  | | |TensorConstant{11} [id E]
#  | | |TensorConstant{2} [id F]
#  | |TensorConstant{1} [id G]
#  | |ScalarConstant{1} [id H]
#  |ScalarConstant{5} [id I]

f_B = aesara.function([], B, mode="FAST_RUN")

# As we can see, the only rewrite applied was making the `IncSubtensor`
# in-place; the outer `Subtensor` was not lifted, so we're still allocating an
# unnecessarily large array (i.e. with shape (11, 2) instead of (5, 2)):
aesara.dprint(f_B)
# Subtensor{:int64:} [id A] 2
#  |IncSubtensor{InplaceSet;:int64:} [id B] 1
#  | |Alloc [id C] 0
#  | | |TensorConstant{0.0} [id D]
#  | | |TensorConstant{11} [id E]
#  | | |TensorConstant{2} [id F]
#  | |TensorConstant{1} [id G]
#  | |ScalarConstant{1} [id H]
#  |ScalarConstant{5} [id I]

f_B()
# array([[1., 1.],
#        [0., 0.],
#        [0., 0.],
#        [0., 0.],
#        [0., 0.]])

Aside from being a generally good optimization to have, it would also simplify, or possibly obviate, all the logic in `save_mem_new_scan`, and perhaps in other places as well.
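To make the intended equivalence concrete, here is a minimal sketch in plain NumPy (not Aesara, and not a proposed implementation of the rewrite). It assumes non-negative slice stops along the first axis and a scalar or row-broadcastable value; the helper names `set_then_slice` and `slice_then_set` are hypothetical and exist only for this illustration.

```python
import numpy as np


def set_then_slice(shape, stop_set, stop_out, value):
    """Current form: allocate the full array, set a leading slice, then slice."""
    x = np.zeros(shape)
    x[:stop_set] = value
    return x[:stop_out]


def slice_then_set(shape, stop_set, stop_out, value):
    """Lifted form: allocate only the rows that survive the outer slice."""
    # Assumption: stop_set and stop_out are non-negative integers.
    n_rows = min(shape[0], stop_out)
    x = np.zeros((n_rows,) + tuple(shape[1:]))
    # NumPy clips the slice to the array bounds, so this also covers
    # the case stop_set > n_rows.
    x[:stop_set] = value
    return x


# The example from this issue: zeros((11, 2))[:1] set to 1, then [:5].
full = set_then_slice((11, 2), 1, 5, 1)
lifted = slice_then_set((11, 2), 1, 5, 1)
print(np.array_equal(full, lifted), lifted.shape)
```

Both forms produce the same values, but the second only ever allocates a `(5, 2)` array. An actual rewrite would need to handle the general cases this sketch ignores (negative or symbolic indices, steps, non-leading axes, and array-valued `A` whose shape matches the set slice).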