tf-encrypted / tf-encrypted

A Framework for Encrypted Machine Learning in TensorFlow

Home Page:https://tf-encrypted.io/

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

AssertionError When running Server

TammamT opened this issue · comments

In the examples provided in "/examples/notebooks/serving/", running the server code fails at the last step with the following error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In [11], line 4
      1 def step_fn():
      2     print("Next")
----> 4 server.run(num_steps=1, step_fn=step_fn)

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/serving/queues.py:56, in QueueServer.run(self, num_steps, step_fn)
     54 if num_steps is not None:
     55     for _ in range(num_steps):
---> 56         self.run_step()
     57         if step_fn is not None:
     58             step_fn()

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File /tmp/__autograph_generated_filegrqv9dy7.py:11, in outer_factory.<locals>.inner_factory.<locals>.tf__computation_step()
      9 with ag__.FunctionScope('computation_step', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
     10     x = ag__.converted_call(ag__.ld(self).input_queue.dequeue, (), None, fscope)
---> 11     y = ag__.converted_call(ag__.ld(computation_fn), (ag__.ld(x),), None, fscope)
     12     ag__.converted_call(ag__.ld(self).output_queue.enqueue, (ag__.ld(y),), None, fscope)

File /tmp/__autograph_generated_fileg8341l89.py:12, in outer_factory.<locals>.inner_factory.<locals>.tf__computation(x)
     10 try:
     11     do_return = True
---> 12     retval_ = (ag__.ld(x) * ag__.ld(x))
     13 except:
     14     do_return = False

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:1727, in PondTensor.__mul__(self, other)
   1726 def __mul__(self, other):
-> 1727     return self.prot.mul(self, other)

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/protocol.py:117, in memoize.<locals>.cache_nodes(self, *args, **kwargs)
    114 if cached_result is not None:
    115     return cached_result
--> 117 result = func(self, *args, **kwargs)
    119 nodes[node_key] = result
    120 return result

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:997, in Pond.mul(self, x, y)
    994 @memoize
    995 def mul(self, x, y):
    996     x, y = self.lift(x, y)
--> 997     return self.dispatch("mul", x, y)

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:1444, in Pond.dispatch(self, base_name, container, *args, **kwargs)
   1442 func = getattr(container, func_name, None)
   1443 if func is not None:
-> 1444     return func(self, *args, **kwargs)  # pylint: disable=not-callable
   1445 raise TypeError(
   1446     ("Don't know how to {}: " "{}").format(
   1447         base_name, [type(arg) for arg in args]
   1448     )
   1449 )

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:2922, in _mul_private_private(prot, x, y)
   2920 assert isinstance(x, PondPrivateTensor), type(x)
   2921 assert isinstance(y, PondPrivateTensor), type(y)
-> 2922 return prot.mul(prot.mask(x), prot.mask(y))

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:986, in Pond.mask(self, x)
    983     return x_masked
    985 if isinstance(x, PondPrivateTensor):
--> 986     x_masked = _mask_private(self, x)
    988 else:
    989     raise TypeError("Don't know how to mask {}".format(type(x)))

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:4071, in _mask_private(prot, x)
   4068         with tf.device(prot.server_1.device_name):
   4069             alpha_on_1 = alpha0 + alpha1
-> 4071 return PondMaskedTensor(
   4072     prot,
   4073     x,
   4074     a,
   4075     a0,
   4076     a1,
   4077     alpha_on_0,
   4078     alpha_on_1,
   4079     x.is_scaled,
   4080 )

File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:2035, in PondMaskedTensor.__init__(self, prot, unmasked, a, a0, a1, alpha_on_0, alpha_on_1, is_scaled)
   2033 assert isinstance(unmasked, PondPrivateTensor)
   2034 assert a.device == prot.triple_source.producer.device_name
-> 2035 assert a0.device == prot.server_0.device_name
   2036 assert alpha_on_0.device == prot.server_0.device_name
   2037 assert a1.device == prot.server_1.device_name

AssertionError: in user code:

    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/serving/queues.py", line 43, in computation_step  *
        y = computation_fn(x)
    File "/tmp/ipykernel_1412/1491841022.py", line 5, in computation  *
        return x * x
    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 1727, in __mul__
        return self.prot.mul(self, other)
    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/protocol.py", line 117, in cache_nodes
        result = func(self, *args, **kwargs)
    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 997, in mul
        return self.dispatch("mul", x, y)
    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 1444, in dispatch
        return func(self, *args, **kwargs)  # pylint: disable=not-callable
    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 2922, in _mul_private_private
        return prot.mul(prot.mask(x), prot.mask(y))
    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 986, in mask
        x_masked = _mask_private(self, x)
    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 4071, in _mask_private
        return PondMaskedTensor(
    File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 2035, in __init__
        assert a0.device == prot.server_0.device_name

    AssertionError: 

The code is running on WSL (Ubuntu 20.04) with Python 3.8.15, tf-encrypted 0.9.0, and TensorFlow 2.9.3.

When upgrading TFE to be based on TF2, we found it was easy to forget to assign an op to a device, which is a serious security issue.
So we added a device check to make sure every share is on the right device before a private tensor is constructed. However, after adding that check we only fully tested the ABY3 protocol, since we mainly focus on ABY3. Thank you for opening this issue; we will fix this in the next commit. For now, if you don't care about security, you can simply comment out these checks.

Thank you, I will wait for the next commit.