AssertionError: Have you provided bias as a kwarg? If so, please remove `bias=`.
mustansarsaeed opened this issue · comments
Mustansar Saeed commented
Hi, I am using PySyft 0.2.x and trying to build a `training_plan` for a CNN model. My model is as follows:
Model
import torch as th
from torch import nn

import syft as sy
from syft.frameworks.torch.nn import Conv2d, max_pool2d
class FemnistNet(nn.Module):
    """CNN for single-channel 28x28 images with 62 output classes.

    Shapes for one 28x28 input:
      conv1 5x5 (1 -> 32)   -> 32x24x24, ReLU, 2x2/2 max-pool -> 32x12x12
      conv2 5x5 (32 -> 64)  -> 64x8x8,   ReLU, 2x2/2 max-pool -> 64x4x4
      flatten (1024) -> fc1 (512), ReLU -> fc2 (62)

    NOTE(review): ``Conv2d`` is syft's Python port, not ``torch.nn.Conv2d``.
    In the reported PySyft 0.2.9 run, its forward reached
    ``syft.frameworks.torch.nn.functional.conv2d``. That function asserts
    ``isinstance(bias, sy.AdditiveSharingTensor)``, so a bias that is not
    secret-shared (as during plain plan building) trips the assertion.
    Confirm that this layer is supported outside SMPC before using it in a Plan.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(1, 32, kernel_size=5, stride=1, padding=0, bias=True)
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.conv2 = Conv2d(32, 64, kernel_size=5, stride=1, padding=0, bias=True)
        self.pool2 = nn.MaxPool2d(2, stride=2)
        # 64 channels * 4 * 4 spatial positions after the second pool.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 62)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: tensor whose elements can be viewed as (-1, 1, 28, 28),
               e.g. a batch of 784-value rows.

        Returns:
            A tuple ``(logits, l1_activations)``:
            - ``logits``: unnormalized class scores, shape (batch, 62).
            - ``l1_activations``: post-ReLU fc1 outputs, shape (batch, 512).
        """
        x = x.view(-1, 1, 28, 28)
        x = self.pool1(th.nn.functional.relu(self.conv1(x)))
        x = self.pool2(th.nn.functional.relu(self.conv2(x)))
        x = x.flatten(start_dim=1)
        l1_activations = th.nn.functional.relu(self.fc1(x))
        # Return raw scores. The training plan treats this output as logits and
        # passes it to a loss named ``cross_entropy_with_logits``. The previous
        # ``x.softmax()`` also omitted the required ``dim`` argument.
        logits = self.fc2(l1_activations)
        return logits, l1_activations
Training Plan
@sy.func2plan()
def training_plan(X, y, batch_size, lr, model_params):
    """One training step on a batch, traced into a PySyft plan.

    Relies on module-level ``model``, ``set_model_params``,
    ``cross_entropy_with_logits`` and ``naive_sgd``, which are defined
    outside this snippet.

    Args:
        X: input batch passed to ``model.forward``.
        y: one-hot targets; argmax over dim 1 yields class indices.
        batch_size: 1-element float tensor; loss and accuracy are normalized by it.
        lr: learning-rate tensor forwarded to ``naive_sgd``.
        model_params: parameter tensors injected into ``model`` before the forward pass.

    Returns:
        ``(loss, acc, logits, target, *updated_params)``, where ``target`` holds
        the class indices recovered from ``y``.
    """
    model.train()
    # Inject the plan's parameter inputs into the model.
    set_model_params(model, model_params)

    logits, _l1_activations = model.forward(X)  # activations unused here
    loss = cross_entropy_with_logits(logits, y, batch_size)

    # Backprop, then one update per parameter.
    # NOTE(review): naive_sgd presumably reads param.grad — confirm.
    loss.backward()
    updated_params = [naive_sgd(param, lr=lr) for param in model_params]

    # Accuracy: fraction of argmax predictions matching the one-hot targets.
    pred = th.argmax(logits, dim=1)
    target = th.argmax(y, dim=1)
    acc = pred.eq(target).sum() / batch_size

    # The caller unpacks the 4th output as ``target``; return it rather than a
    # bare ``None`` placeholder.
    return (loss, acc, logits, target, *updated_params)
Building training plan
num = 50  # maximum number of samples in the build-time batch

X = th.tensor(dataX, dtype=th.float)  # rows of 784 values; the model reshapes them to (1, 28, 28)
y = nn.functional.one_hot(th.tensor(dataY), 62)  # one-hot targets over 62 classes
lr = th.tensor([0.0003])  # learning rate

X_batch, y_batch = X[:num], y[:num]
# Derive batch_size from the slice actually passed in, so loss/accuracy
# normalization stays correct when fewer than `num` samples are available.
batch_size = th.tensor([float(len(X_batch))])

loss, acc, logits, target, *updated_params = training_plan.build(
    X_batch, y_batch, batch_size, lr, model_params, trace_autograd=True
)
Error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-196-b1aa83563586> in <module>
6 batch_size = th.tensor([float(num)]) ##20 is our batch size
7
----> 8 loss, acc, logits, target, *updated_params = training_plan.build(X[0:num], y[0:num], batch_size, lr, model_params, trace_autograd=True)
9 # updated_params
10 loss
/usr/local/lib/python3.7/dist-packages/syft-0.2.9-py3.7.egg/syft/execution/plan.py in build(self, trace_autograd, *args)
273 framework_kwargs[f_name] = wrap_framework_func(self.role)
274
--> 275 results = self.forward(*args, **framework_kwargs)
276
277 # Register inputs in role
<ipython-input-131-e49ab938819a> in training_plan(X, y, batch_size, lr, model_params)
12
13
---> 14 logits, activations = model.forward(X)
15 # print("Logits", logits[:1])
16 loss = cross_entropy_with_logits(logits, y, batch_size)
<ipython-input-193-80ed0341fc1a> in forward(self, x)
45 # print(x.shape)
46
---> 47 x = self.conv1(x)
48 x = th.nn.functional.relu(x)
49 # print("conv1 shape", x.shape)
/usr/local/lib/python3.7/dist-packages/torch-1.4.0-py3.7-linux-x86_64.egg/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
/usr/local/lib/python3.7/dist-packages/syft-0.2.9-py3.7.egg/syft/frameworks/torch/nn/conv.py in forward(self, input)
70
71 return conv2d(
---> 72 input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
73 )
74
/usr/local/lib/python3.7/dist-packages/syft-0.2.9-py3.7.egg/syft/frameworks/torch/nn/functional.py in conv2d(input, weight, bias, stride, padding, dilation, groups)
245 assert isinstance(
246 bias, sy.AdditiveSharingTensor
--> 247 ), "Have you provided bias as a kwarg? If so, please remove `bias=`."
248
249 locations = input.locations
AssertionError: Have you provided bias as a kwarg? If so, please remove `bias=`.
Can anyone please tell me what the issue is?