roastduck / FreeTensor

A language and compiler for irregular tensor programs.

Home Page: https://roastduck.github.io/FreeTensor/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Remove Writes incorrectly removed some writes used later

Blealtan opened this issue · comments

import ir
# Reproduction script: builds an AST with nested VarDef scopes, wraps it in a
# Func, and prints the function before and after ir.remove_writes.
with ir.VarDef("u", (64, 2, ), "float64", "input", "cpu") as u: 
    with ir.VarDef("spmv", (64, 2, ), "float64", "output", "cpu") as spmv:
        with ir.VarDef("tmp", (2, ), "float64", "cache", "cpu") as tmp: 
            # Zero the accumulator.
            with ir.For("i$1", 0, 2) as i_361:
                tmp[i_361] = 0
            with ir.VarDef("lazyrow__value", (2, 2, ), "float32", "cache", "cpu") as lazyrow____value: 
                # First group of writes: lazyrow__value = [[-1, 0], [0, -1]].
                # These values are read by the first i$2 loop below, so they
                # must be kept; the issue reports that remove_writes drops them.
                lazyrow____value[0, 0] = -1
                lazyrow____value[0, 1] = 0
                lazyrow____value[1, 0] = 0
                lazyrow____value[1, 1] = -1
                # "product" is defined here but never read or written in this script.
                with ir.VarDef("product", (2, ), "float64", "cache", "cpu") as product: 
                    # Accumulate using the FIRST group of lazyrow__value writes and row 0 of u.
                    with ir.For("i$2", 0, 2) as i_362:
                        tmp[i_362] += ((lazyrow____value[i_362, 0] * u[0, 0]) + (lazyrow____value[i_362, 1] * u[0, 1]))
                    # Second group of writes: lazyrow__value = [[1, 0], [0, 1]].
                    # They overwrite the first group only AFTER the loop above has read it.
                    lazyrow____value[0, 0] = 1
                    lazyrow____value[0, 1] = 0
                    lazyrow____value[1, 0] = 0
                    lazyrow____value[1, 1] = 1
                # Second "product" scope; again the variable itself is unused.
                with ir.VarDef("product", (2, ), "float64", "cache", "cpu") as product: 
                    # Accumulate using the SECOND group of lazyrow__value writes and row 1 of u.
                    with ir.For("i$2", 0, 2) as i_362:
                        tmp[i_362] += ((lazyrow____value[i_362, 0] * u[1, 0]) + (lazyrow____value[i_362, 1] * u[1, 1]))
                # Copy the accumulated result into row 0 of the output.
                with ir.For("i$4", 0, 2) as i_364:
                    spmv[0, i_364] = tmp[i_364]

# Wrap the AST in a function taking "u" and returning "spmv" as float64.
# NOTE(review): presumably ir.pop_ast() returns the AST built by the
# with-blocks above -- confirm against the ir module.
f = ir.Func("fuck", ["u"], [("spmv", ir.parseDType("float64"))], ir.pop_ast())
print(f)  # function as constructed
print(ir.remove_writes(f))  # function after the remove_writes pass

It outputs the following: first the original function, then the result of `remove_writes`.

func(u) -> spmv: f64 {
  [in] [CPU] u: f64[64, 2] {
    [out] [CPU] spmv: f64[64, 2] {
      [cache] [CPU] tmp: f64[2] {
        for i$1 in 0 : 2 : 1 {
          tmp[i$1] = 0
        }
        [cache] [CPU] lazyrow__value: f32[2, 2] {
          lazyrow__value[0, 0] = -1
          lazyrow__value[0, 1] = 0
          lazyrow__value[1, 0] = 0
          lazyrow__value[1, 1] = -1
          [cache] [CPU] product: f64[2] {
            for i$2 in 0 : 2 : 1 {
              tmp[i$2] = (tmp[i$2] + ((lazyrow__value[i$2, 0] * u[0, 0]) + (lazyrow__value[i$2, 1] * u[0, 1])))
            }
            lazyrow__value[0, 0] = 1
            lazyrow__value[0, 1] = 0
            lazyrow__value[1, 0] = 0
            lazyrow__value[1, 1] = 1
          }
          [cache] [CPU] product: f64[2] {
            for i$2 in 0 : 2 : 1 {
              tmp[i$2] = (tmp[i$2] + ((lazyrow__value[i$2, 0] * u[1, 0]) + (lazyrow__value[i$2, 1] * u[1, 1])))
            }
          }
          for i$4 in 0 : 2 : 1 {
            spmv[0, i$4] = tmp[i$4]
          }
        }
      }
    }
  }
}

func(u) -> spmv: f64 {
  [in] [CPU] u: f64[64, 2] {
    [out] [CPU] spmv: f64[64, 2] {
      [cache] [CPU] tmp: f64[2] {
        [cache] [CPU] lazyrow__value: f32[2, 2] {
          for i$2 in 0 : 2 : 1 {
            tmp[i$2] = (0 + ((lazyrow__value[i$2, 0] * u[0, 0]) + (lazyrow__value[i$2, 1] * u[0, 1])))
          }
          lazyrow__value[0, 0] = 1
          lazyrow__value[0, 1] = 0
          lazyrow__value[1, 0] = 0
          lazyrow__value[1, 1] = 1
          for i$2 in 0 : 2 : 1 {
            tmp[i$2] += ((lazyrow__value[i$2, 0] * u[1, 0]) + (lazyrow__value[i$2, 1] * u[1, 1]))
          }
        }
        for i$4 in 0 : 2 : 1 {
          spmv[0, i$4] = tmp[i$4]
        }
      }
    }
  }
}

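Note that the initial writes to lazyrow__value (-1, 0, 0, -1) are missing from the second printout, so the early accesses to lazyrow__value[i$2, 0] and lazyrow__value[i$2, 1] in the first i$2 loop now read uninitialized data.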

A smaller test script

# Minimal reproduction: the cache A is written, read by a loop, overwritten,
# and read again by a second loop. Both groups of writes to A are live and
# must be kept by ir.remove_writes.
#
# u is declared 2-D as (64, 2) because every access below uses two subscripts
# (u[0, 0], u[0, 1], u[1, 0], u[1, 1]); the previous 1-D shape (64,) did not
# match those accesses.
with ir.VarDef("u", (64, 2), "float64", "input", "cpu") as u:
    with ir.VarDef("y", (2,), "float64", "output", "cpu") as y:
        with ir.VarDef("tmp", (2,), "float64", "cache", "cpu") as tmp:
            # Zero the accumulator.
            with ir.For("i", 0, 2) as i:
                tmp[i] = 0
            with ir.VarDef("A", (2,), "float32", "cache", "cpu") as A:
                # First group of writes to A; read by the loop directly below.
                A[0] = 0
                A[1] = 1
                with ir.For("i", 0, 2) as i:
                    tmp[i] += ((A[i] * u[0, 0]) + (A[i] * u[0, 1]))
                # Second group of writes to A; they overwrite the first group
                # only after it has been read, and feed the next loop.
                A[0] = 2
                A[1] = 3
                with ir.For("i", 0, 2) as i:
                    tmp[i] += ((A[i] * u[1, 0]) + (A[i] * u[1, 1]))
                # Copy the accumulated result to the output.
                with ir.For("i", 0, 2) as i:
                    y[i] = tmp[i]
ast = ir.pop_ast()
print(ast)  # AST as constructed
ast = ir.remove_writes(ast)
print(ast)  # AST after the remove_writes pass

Fixed in c593a50