tf-encrypted / moose

Secure distributed dataflow framework for encrypted machine learning and data processing

Make AbstractComputation nest-able

jvmncs opened this issue

Functions written with the eDSL should be callable from within other computations, regardless of whether they've been wrapped with the pm.computation decorator. For example,

@pm.computation
def plus1(x: pm.Argument(alice, dtype=pm.float64)):
  with alice:
    one = pm.constant(1, dtype=pm.float64)
    return pm.add(x, one)

@pm.computation
def alice_add():
  with alice:
    x = pm.constant(3, dtype=pm.float64)
    x_plus_one = plus1(x)
  return x_plus_one

if __name__ == "__main__":
  [...]
  runtime.set_default()
  val = alice_add()  # <-- will fail during tracing

When alice_add is called, the current behavior is as follows:

  • Inside a runtime context, alice_add.__call__ invokes trace(alice_add).
  • trace(alice_add) then invokes plus1.__call__. For this call to succeed, plus1 must return an Expression that can be used to trace the rest of alice_add.
  • However, since the default runtime context is not None, plus1 is instead evaluated through the default runtime's evaluate_computation, with arguments of type Expression.
  • The Rust runtime bindings try to interpret these Expression PyObjects as Moose Values, which fails with a TypeError because they are not concrete values.
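To make the failure mode concrete, here is a self-contained toy model that reproduces it without Moose. Expression, Runtime, AbstractComputation, trace, constant and add are hypothetical stand-ins that only mimic the behavior described above (the toy runtime rejects Expression arguments the way the Rust bindings do); they are not the pymoose API.

```python
# Toy model of the nested-call failure. Every name below is a hypothetical
# stand-in for illustration, not the real pymoose API.
from typing import Any, Callable, List, Optional


class Expression:
    """Symbolic graph node produced while tracing."""

    def __init__(self, op: str, inputs: List["Expression"], value: Any = None):
        self.op = op
        self.inputs = inputs
        self.value = value


def constant(value: float) -> Expression:
    return Expression("constant", [], value)


def add(x: Expression, y: Expression) -> Expression:
    return Expression("add", [x, y])


_default_runtime: Optional["Runtime"] = None


def get_current_runtime() -> Optional["Runtime"]:
    return _default_runtime


class Runtime:
    def set_default(self) -> None:
        global _default_runtime
        _default_runtime = self

    def evaluate_computation(self, comp: "AbstractComputation", args: List[Any]) -> Any:
        # Mimics the Rust bindings: only concrete values are accepted.
        for arg in args:
            if isinstance(arg, Expression):
                raise TypeError("expected a concrete value, got Expression")
        # Returns the traced graph; actual execution is omitted in this toy.
        return trace(comp, [constant(a) for a in args])


class AbstractComputation:
    def __init__(self, func: Callable[..., Expression]):
        self.func = func

    def __call__(self, *args: Any) -> Any:
        runtime = get_current_runtime()
        if runtime is None:
            raise RuntimeError("no runtime context")
        # Current behavior: always dispatch to the runtime when one is set.
        return runtime.evaluate_computation(self, list(args))


def trace(comp: AbstractComputation, args: List[Expression]) -> Expression:
    # Tracing runs the Python function on Expressions, but the default
    # runtime is still set while this happens.
    return comp.func(*args)


@AbstractComputation  # stands in for @pm.computation
def plus1(x: Expression) -> Expression:
    return add(x, constant(1.0))


@AbstractComputation
def alice_add() -> Expression:
    x = constant(3.0)
    return plus1(x)  # nested call: plus1 receives an Expression


Runtime().set_default()
failure: Optional[TypeError] = None
try:
    alice_add()
except TypeError as err:
    failure = err
print(failure)  # expected a concrete value, got Expression
```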

One workaround for the user is to drop the pm.computation decorator from plus1, so that it returns an Expression regardless of the surrounding runtime context. However, this makes it hard to reuse "standard library" computations that are already decorated with pm.computation (i.e. wrapped in an AbstractComputation), which will likely be the common case.

I think the simplest solution here would be to do the following:

  • Inside pm.trace, temporarily unset the default runtime context, so that get_current_runtime returns None.
  • If AbstractComputation.__call__ is invoked without a runtime context (i.e. get_current_runtime returns None), invoke AbstractComputation.func.__call__. This invocation maps Expressions to Expressions, so tracing can proceed normally.
  • If AbstractComputation.__call__ is invoked inside a runtime context, invoke get_current_runtime().evaluate_computation(...) with the computation as usual.
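A minimal sketch of this proposal, again using hypothetical stand-ins rather than the real pymoose classes. It assumes the current runtime is held in a contextvars.ContextVar so that trace can unset it and restore it afterwards; how pymoose actually stores the default runtime may differ, and evaluate is a tiny interpreter standing in for real execution.

```python
# Sketch of the proposed fix: trace() temporarily unsets the current runtime,
# and AbstractComputation.__call__ calls func directly when no runtime is set.
# All names are hypothetical stand-ins, not the pymoose API.
from contextvars import ContextVar
from typing import Any, Callable, List, Optional


class Expression:
    def __init__(self, op: str, inputs: List["Expression"], value: Any = None):
        self.op = op
        self.inputs = inputs
        self.value = value


def constant(value: float) -> Expression:
    return Expression("constant", [], value)


def add(x: Expression, y: Expression) -> Expression:
    return Expression("add", [x, y])


def evaluate(expr: Expression) -> float:
    # Tiny interpreter standing in for actual execution by the runtime.
    if expr.op == "constant":
        return expr.value
    if expr.op == "add":
        return evaluate(expr.inputs[0]) + evaluate(expr.inputs[1])
    raise ValueError(f"unknown op {expr.op!r}")


_current_runtime: "ContextVar[Optional[Runtime]]" = ContextVar("current_runtime", default=None)


def get_current_runtime() -> Optional["Runtime"]:
    return _current_runtime.get()


class Runtime:
    def set_default(self) -> None:
        _current_runtime.set(self)

    def evaluate_computation(self, comp: "AbstractComputation", args: List[Any]) -> float:
        if any(isinstance(a, Expression) for a in args):
            raise TypeError("expected a concrete value, got Expression")
        graph = trace(comp, [constant(a) for a in args])
        return evaluate(graph)


class AbstractComputation:
    def __init__(self, func: Callable[..., Expression]):
        self.func = func

    def __call__(self, *args: Any) -> Any:
        runtime = get_current_runtime()
        if runtime is None:
            # No runtime context (e.g. during tracing): map Expressions to
            # Expressions by calling the wrapped function directly.
            return self.func(*args)
        return runtime.evaluate_computation(self, list(args))


def trace(comp: AbstractComputation, args: List[Expression]) -> Expression:
    token = _current_runtime.set(None)  # temporarily unset the runtime
    try:
        return comp.func(*args)
    finally:
        _current_runtime.reset(token)  # restore it, even if tracing fails


@AbstractComputation  # stands in for @pm.computation
def plus1(x: Expression) -> Expression:
    return add(x, constant(1.0))


@AbstractComputation
def alice_add() -> Expression:
    x = constant(3.0)
    return plus1(x)  # now traced inline instead of sent to the runtime


Runtime().set_default()
val = alice_add()
print(val)  # 4.0
```

Restoring the runtime in a finally block means an exception raised while tracing cannot leave the default runtime unset, and a decorated computation such as plus1 still works when called at the top level (plus1(10.0) evaluates to 11.0 in this toy).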

Some other options:

  • Allow runtime contexts to nest, and create a new "dummy" Runtime class whose evaluate_computation simply forwards to AbstractComputation.func.__call__.
  • Something "moose-ier", e.g. support Expression conversion in the Moose bindings and, in this case, execute symbolically, i.e. run the computation against a SymbolicSession instead of the AsyncSession in AsyncTestRuntime.