AssertionError When running Server
TammamT opened this issue · comments
I am using the examples provided in "/examples/notebooks/serving/".
When running the server code, the last step results in the following error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In [11], line 4
1 def step_fn():
2 print("Next")
----> 4 server.run(num_steps=1, step_fn=step_fn)
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/serving/queues.py:56, in QueueServer.run(self, num_steps, step_fn)
54 if num_steps is not None:
55 for _ in range(num_steps):
---> 56 self.run_step()
57 if step_fn is not None:
58 step_fn()
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
File /tmp/__autograph_generated_filegrqv9dy7.py:11, in outer_factory.<locals>.inner_factory.<locals>.tf__computation_step()
9 with ag__.FunctionScope('computation_step', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
10 x = ag__.converted_call(ag__.ld(self).input_queue.dequeue, (), None, fscope)
---> 11 y = ag__.converted_call(ag__.ld(computation_fn), (ag__.ld(x),), None, fscope)
12 ag__.converted_call(ag__.ld(self).output_queue.enqueue, (ag__.ld(y),), None, fscope)
File /tmp/__autograph_generated_fileg8341l89.py:12, in outer_factory.<locals>.inner_factory.<locals>.tf__computation(x)
10 try:
11 do_return = True
---> 12 retval_ = (ag__.ld(x) * ag__.ld(x))
13 except:
14 do_return = False
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:1727, in PondTensor.__mul__(self, other)
1726 def __mul__(self, other):
-> 1727 return self.prot.mul(self, other)
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/protocol.py:117, in memoize.<locals>.cache_nodes(self, *args, **kwargs)
114 if cached_result is not None:
115 return cached_result
--> 117 result = func(self, *args, **kwargs)
119 nodes[node_key] = result
120 return result
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:997, in Pond.mul(self, x, y)
994 @memoize
995 def mul(self, x, y):
996 x, y = self.lift(x, y)
--> 997 return self.dispatch("mul", x, y)
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:1444, in Pond.dispatch(self, base_name, container, *args, **kwargs)
1442 func = getattr(container, func_name, None)
1443 if func is not None:
-> 1444 return func(self, *args, **kwargs) # pylint: disable=not-callable
1445 raise TypeError(
1446 ("Don't know how to {}: " "{}").format(
1447 base_name, [type(arg) for arg in args]
1448 )
1449 )
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:2922, in _mul_private_private(prot, x, y)
2920 assert isinstance(x, PondPrivateTensor), type(x)
2921 assert isinstance(y, PondPrivateTensor), type(y)
-> 2922 return prot.mul(prot.mask(x), prot.mask(y))
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:986, in Pond.mask(self, x)
983 return x_masked
985 if isinstance(x, PondPrivateTensor):
--> 986 x_masked = _mask_private(self, x)
988 else:
989 raise TypeError("Don't know how to mask {}".format(type(x)))
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:4071, in _mask_private(prot, x)
4068 with tf.device(prot.server_1.device_name):
4069 alpha_on_1 = alpha0 + alpha1
-> 4071 return PondMaskedTensor(
4072 prot,
4073 x,
4074 a,
4075 a0,
4076 a1,
4077 alpha_on_0,
4078 alpha_on_1,
4079 x.is_scaled,
4080 )
File ~/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py:2035, in PondMaskedTensor.__init__(self, prot, unmasked, a, a0, a1, alpha_on_0, alpha_on_1, is_scaled)
2033 assert isinstance(unmasked, PondPrivateTensor)
2034 assert a.device == prot.triple_source.producer.device_name
-> 2035 assert a0.device == prot.server_0.device_name
2036 assert alpha_on_0.device == prot.server_0.device_name
2037 assert a1.device == prot.server_1.device_name
AssertionError: in user code:
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/serving/queues.py", line 43, in computation_step *
y = computation_fn(x)
File "/tmp/ipykernel_1412/1491841022.py", line 5, in computation *
return x * x
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 1727, in __mul__
return self.prot.mul(self, other)
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/protocol.py", line 117, in cache_nodes
result = func(self, *args, **kwargs)
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 997, in mul
return self.dispatch("mul", x, y)
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 1444, in dispatch
return func(self, *args, **kwargs) # pylint: disable=not-callable
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 2922, in _mul_private_private
return prot.mul(prot.mask(x), prot.mask(y))
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 986, in mask
x_masked = _mask_private(self, x)
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 4071, in _mask_private
return PondMaskedTensor(
File "/home/wsl/miniconda3/envs/tfe/lib/python3.8/site-packages/tf_encrypted/protocol/pond/pond.py", line 2035, in __init__
assert a0.device == prot.server_0.device_name
AssertionError:
The code is running on WSL (Ubuntu 20.04), with Python 3.8.15, tf-encrypted 0.9.0, and TensorFlow 2.9.3.
When upgrading TFE to be based on TF2, we found that it is easy to forget to assign an op to a device, which is a serious security issue.
So we added a device check to make sure every share is on the right device before constructing a private tensor. However, we only fully tested the ABY3 protocol after adding that check, since we mainly focus on the ABY3 protocol. Thank you for opening this issue; we will fix this in the next commit. For now, if you don't care about security, you can just comment out these checks.
Thank you, I will wait for the next commit.