ts-plus / typescript

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

TS+ `tailRec` is breaking proper code execution in certain scenarios

IMax153 opened this issue · comments

In @effect/core, we have a data type called Intervals which represent a set of Intervals used by Schedule to make scheduling decisions. Given that the Intervals data type is conceptually represented as a set, it has both union and intersect combinators.

In the ZIO implementation, both of these combinators make use of @tailrec to avoid overflowing the stack with recursion.

However, in the @effect/core implementation, use of @tsplus tailRec actually breaks proper execution of this method.

Example:

Given two Intervals:

Interval 1: ["2022-08-01T12:20:43.409Z - 2022-08-02T00:00:00.000Z"]
Interval 2: ["2022-08-03T00:00:00.000Z - 2022-08-04T00:00:00.000Z"]

the result of Intervals.union should be:

Union: [
  "2022-08-01T12:20:43.409Z - 2022-08-02T00:00:00.000Z",
  "2022-08-03T00:00:00.000Z - 2022-08-04T00:00:00.000Z"
]

However, when @tsplus tailRec is used, the result is:

Union: ["2022-08-01T12:19:39.942Z - 2022-08-02T00:00:00.000Z"]

Removing the TS+ tailRec annotation returns the proper output:

Union: [
  "2022-08-01T12:20:43.409Z - 2022-08-02T00:00:00.000Z",
  "2022-08-03T00:00:00.000Z - 2022-08-04T00:00:00.000Z"
]
Uncompiled Code

/**
 * Produces the union of this set of intervals and the specified set of intervals.
 *
 * @tsplus pipeable-operator effect/core/io/Schedule/Intervals ||
 * @tsplus static effect/core/io/Schedule/Intervals.Aspects union
 * @tsplus pipeable effect/core/io/Schedule/Intervals union
 */
export function union(that: Intervals) {
  return (self: Intervals): Intervals => {
    if (that.intervals.isNil()) {
      return self
    }
    if (self.intervals.isNil()) {
      return that
    }
    // const { head: left, tail: lefts } = self.intervals
    // const { head: right, tail: rights } = that.intervals
    if (self.intervals.head.startMillis < that.intervals.head.startMillis) {
      return unionLoop(self.intervals.tail, that.intervals, self.intervals.head, List.nil())
    }
    return unionLoop(self.intervals, that.intervals.tail, that.intervals.head, List.nil())
  }
}

/**
 * @tsplus tailRec
 */
function unionLoop(
  self: List<Interval>,
  that: List<Interval>,
  interval: Interval,
  acc: List<Interval>
): Intervals {
  switch (self._tag) {
    case "Nil": {
      switch (that._tag) {
        case "Nil": {
          return Intervals(acc.prepend(interval).reverse)
        }
        case "Cons": {
          // const { head: right, tail: rights } = that
          if (interval.endMillis < that.head.startMillis) {
            return unionLoop(List.nil(), that.tail, that.head, acc.prepend(interval))
          }
          return unionLoop(
            List.nil(),
            that.tail,
            Interval(interval.startMillis, that.head.endMillis),
            acc
          )
        }
      }
    }
    case "Cons": {
      switch (that._tag) {
        case "Nil": {
          // const { head: left, tail: lefts } = self
          if (interval.endMillis < self.head.startMillis) {
            return unionLoop(self.tail, List.nil(), self.head, acc.prepend(interval))
          }
          return unionLoop(
            self.tail,
            List.nil(),
            Interval(interval.startMillis, self.head.endMillis),
            acc
          )
        }
        case "Cons": {
          // const { head: left, tail: lefts } = self
          // const { head: right, tail: rights } = that
          if (self.head.startMillis < that.head.startMillis) {
            if (interval.endMillis < self.head.startMillis) {
              return unionLoop(self.tail, that, self.head, acc.prepend(interval))
            }
            return unionLoop(
              self.tail,
              that,
              Interval(interval.startMillis, self.head.endMillis),
              acc
            )
          }
          if (interval.endMillis < that.head.startMillis) {
            return unionLoop(self, that.tail, that.head, acc.prepend(interval))
          }
          return unionLoop(
            self,
            that.tail,
            Interval(interval.startMillis, that.head.endMillis),
            acc
          )
        }
      }
    }
  }
}

Compiled Code (No TS+ tailRec)

/**
* Produces the union of this set of intervals and the specified set of intervals.
*
* @tsplus pipeable-operator effect/core/io/Schedule/Intervals ||
* @tsplus static effect/core/io/Schedule/Intervals.Aspects union
* @tsplus pipeable effect/core/io/Schedule/Intervals union
*/
function union_1(that) {
  return (self) => {
      if (tsplus_module_3.isNil(that.intervals)) {
          return self;
      }
      if (tsplus_module_3.isNil(self.intervals)) {
          return that;
      }
      // const { head: left, tail: lefts } = self.intervals
      // const { head: right, tail: rights } = that.intervals
      if (self.intervals.head.startMillis < that.intervals.head.startMillis) {
          return unionLoop(self.intervals.tail, that.intervals, self.intervals.head, tsplus_module_3.nil());
      }
      return unionLoop(self.intervals, that.intervals.tail, that.intervals.head, tsplus_module_3.nil());
  };
}
function unionLoop(self, that, interval, acc) {
  switch (self._tag) {
      case "Nil": {
          switch (that._tag) {
              case "Nil": {
                  return make_1(tsplus_module_5.reverse(tsplus_module_4.prepend_(acc, interval)));
              }
              case "Cons": {
                  // const { head: right, tail: rights } = that
                  if (interval.endMillis < that.head.startMillis) {
                      return unionLoop(tsplus_module_3.nil(), that.tail, that.head, tsplus_module_4.prepend_(acc, interval));
                  }
                  return unionLoop(tsplus_module_3.nil(), that.tail, tsplus_module_6.fromStartEndMillis(interval.startMillis, that.head.endMillis), acc);
              }
          }
      }
      case "Cons": {
          switch (that._tag) {
              case "Nil": {
                  // const { head: left, tail: lefts } = self
                  if (interval.endMillis < self.head.startMillis) {
                      return unionLoop(self.tail, tsplus_module_3.nil(), self.head, tsplus_module_4.prepend_(acc, interval));
                  }
                  return unionLoop(self.tail, tsplus_module_3.nil(), tsplus_module_6.fromStartEndMillis(interval.startMillis, self.head.endMillis), acc);
              }
              case "Cons": {
                  // const { head: left, tail: lefts } = self
                  // const { head: right, tail: rights } = that
                  if (self.head.startMillis < that.head.startMillis) {
                      if (interval.endMillis < self.head.startMillis) {
                          return unionLoop(self.tail, that, self.head, tsplus_module_4.prepend_(acc, interval));
                      }
                      return unionLoop(self.tail, that, tsplus_module_6.fromStartEndMillis(interval.startMillis, self.head.endMillis), acc);
                  }
                  if (interval.endMillis < that.head.startMillis) {
                      return unionLoop(self, that.tail, that.head, tsplus_module_4.prepend_(acc, interval));
                  }
                  return unionLoop(self, that.tail, tsplus_module_6.fromStartEndMillis(interval.startMillis, that.head.endMillis), acc);
              }
          }
      }
  }
}

Compiled Code (With TS+ tailRec)

/**
* Produces the union of this set of intervals and the specified set of intervals.
*
* @tsplus pipeable-operator effect/core/io/Schedule/Intervals ||
* @tsplus static effect/core/io/Schedule/Intervals.Aspects union
* @tsplus pipeable effect/core/io/Schedule/Intervals union
*/
function union_1(that) {
  return (self) => {
      if (tsplus_module_3.isNil(that.intervals)) {
          return self;
      }
      if (tsplus_module_3.isNil(self.intervals)) {
          return that;
      }
      // const { head: left, tail: lefts } = self.intervals
      // const { head: right, tail: rights } = that.intervals
      if (self.intervals.head.startMillis < that.intervals.head.startMillis) {
          return unionLoop(self.intervals.tail, that.intervals, self.intervals.head, tsplus_module_3.nil());
      }
      return unionLoop(self.intervals, that.intervals.tail, that.intervals.head, tsplus_module_3.nil());
  };
}
/**
* @tsplus tailRec
*/
function unionLoop(self, that, interval, acc) {
  var self_1 = self, that_1 = that, interval_1 = interval, acc_1 = acc;
  var self_2 = self, that_2 = that, interval_2 = interval, acc_2 = acc;
  while (1) {
      switch (self_1._tag) {
          case "Nil": {
              switch (that_1._tag) {
                  case "Nil": {
                      return make_1(tsplus_module_5.reverse(tsplus_module_4.prepend_(acc, interval)));
                  }
                  case "Cons": {
                      // const { head: right, tail: rights } = that
                      if (interval_1.endMillis < that_1.head.startMillis) {
                          self_2 = tsplus_module_3.nil();
                          that_2 = that_1.tail;
                          interval_2 = that_1.head;
                          acc_2 = tsplus_module_4.prepend_(acc_1, interval_1);
                          self_1 = self_2;
                          that_1 = that_2;
                          interval_1 = interval_2;
                          acc_1 = acc_2;
                          continue;
                      }
                      self_2 = tsplus_module_3.nil();
                      that_2 = that_1.tail;
                      interval_2 = tsplus_module_6.fromStartEndMillis(interval_1.startMillis, that_1.head.endMillis);
                      acc_2 = acc_1;
                      self_1 = self_2;
                      that_1 = that_2;
                      interval_1 = interval_2;
                      acc_1 = acc_2;
                      continue;
                  }
              }
          }
          case "Cons": {
              switch (that_1._tag) {
                  case "Nil": {
                      // const { head: left, tail: lefts } = self
                      if (interval_1.endMillis < self_1.head.startMillis) {
                          self_2 = self_1.tail;
                          that_2 = tsplus_module_3.nil();
                          interval_2 = self_1.head;
                          acc_2 = tsplus_module_4.prepend_(acc_1, interval_1);
                          self_1 = self_2;
                          that_1 = that_2;
                          interval_1 = interval_2;
                          acc_1 = acc_2;
                          continue;
                      }
                      self_2 = self_1.tail;
                      that_2 = tsplus_module_3.nil();
                      interval_2 = tsplus_module_6.fromStartEndMillis(interval_1.startMillis, self_1.head.endMillis);
                      acc_2 = acc_1;
                      self_1 = self_2;
                      that_1 = that_2;
                      interval_1 = interval_2;
                      acc_1 = acc_2;
                      continue;
                  }
                  case "Cons": {
                      // const { head: left, tail: lefts } = self
                      // const { head: right, tail: rights } = that
                      if (self_1.head.startMillis < that_1.head.startMillis) {
                          if (interval_1.endMillis < self_1.head.startMillis) {
                              self_2 = self_1.tail;
                              that_2 = that_1;
                              interval_2 = self_1.head;
                              acc_2 = tsplus_module_4.prepend_(acc_1, interval_1);
                              self_1 = self_2;
                              that_1 = that_2;
                              interval_1 = interval_2;
                              acc_1 = acc_2;
                              continue;
                          }
                          self_2 = self_1.tail;
                          that_2 = that_1;
                          interval_2 = tsplus_module_6.fromStartEndMillis(interval_1.startMillis, self_1.head.endMillis);
                          acc_2 = acc_1;
                          self_1 = self_2;
                          that_1 = that_2;
                          interval_1 = interval_2;
                          acc_1 = acc_2;
                          continue;
                      }
                      if (interval_1.endMillis < that_1.head.startMillis) {
                          self_2 = self_1;
                          that_2 = that_1.tail;
                          interval_2 = that_1.head;
                          acc_2 = tsplus_module_4.prepend_(acc_1, interval_1);
                          self_1 = self_2;
                          that_1 = that_2;
                          interval_1 = interval_2;
                          acc_1 = acc_2;
                          continue;
                      }
                      self_2 = self_1;
                      that_2 = that_1.tail;
                      interval_2 = tsplus_module_6.fromStartEndMillis(interval_1.startMillis, that_1.head.endMillis);
                      acc_2 = acc_1;
                      self_1 = self_2;
                      that_1 = that_2;
                      interval_1 = interval_2;
                      acc_1 = acc_2;
                      continue;
                  }
              }
          }
      }
  }
}

Mmmmm the transformation we use is not safe when mutable things are in place....

Ah we weren't transforming the arguments of the base case, so it was using the original function arguments in the return Intervals(acc.prepend(interval).reverse) part. #217 should fix it.

This did indeed fix the issue - thank you @0x706b!