MGensheimer / nnet-survival

Discrete-Time Survival Model for Neural Networks


network architecture

XinyeYang opened this issue

Hello, I've been looking at the newly added PyTorch version of the code and I'm confused about one thing. In the nnet_survival_pytorch_example.ipynb notebook, there is this code:
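import torch
import torch.nn as nn

class SimpleSurvivalModel(nn.Module):
    def __init__(self, n_predictors, n_intervals):
        super(SimpleSurvivalModel, self).__init__()
        # A single linear layer maps the predictors to one output per follow-up interval
        self.lin = nn.Linear(n_predictors, n_intervals)
        #self.weight = nn.Parameter(torch.zeros(n_intervals))
        #self.bias = nn.Parameter(torch.zeros(n_intervals))
    def forward(self, x):
        # Sigmoid squashes each interval's output into the 0-1 range
        return torch.sigmoid(self.lin(x))
        #return torch.sigmoid(x @ self.weight + self.bias)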
It is just a simple linear network with no convolutional layers, which differs from the architecture described in the paper. Can I modify the network architecture myself?

Hi,
Thanks for the question. Right, the notebook is a toy example. For a multi-layer model, you would add more layers. For the final layer, you would apply a sigmoid so that the output for each follow-up time period falls in the 0-1 range; that value represents the probability of surviving that time period. I'll try to put together a more complicated example, but it will take me some time.
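As a rough sketch of what the reply describes (not code from the repository), a multi-layer version could stack hidden layers before a final layer with one sigmoid output per follow-up interval. The class name `MultiLayerSurvivalModel`, the helper `survival_curve`, the hidden sizes, ReLU activations, and dropout are all hypothetical choices for illustration. The sketch also assumes, as in discrete-time survival models generally, that each output is the conditional probability of surviving an interval given survival to its start, so the survival curve is the running product of the outputs.

```python
import torch
import torch.nn as nn


class MultiLayerSurvivalModel(nn.Module):
    """Hypothetical multi-layer variant of SimpleSurvivalModel.

    Hidden layer sizes, ReLU and dropout are illustrative assumptions;
    the final layer has one output per follow-up interval.
    """

    def __init__(self, n_predictors, n_intervals, hidden_sizes=(32, 16), dropout=0.1):
        super().__init__()
        layers = []
        in_features = n_predictors
        for h in hidden_sizes:
            # Hidden block: linear -> ReLU -> dropout
            layers += [nn.Linear(in_features, h), nn.ReLU(), nn.Dropout(dropout)]
            in_features = h
        # Final layer: one logit per follow-up interval
        layers.append(nn.Linear(in_features, n_intervals))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Sigmoid maps each logit into the 0-1 range, interpreted here as the
        # (assumed conditional) probability of surviving that interval
        return torch.sigmoid(self.net(x))


def survival_curve(interval_probs):
    # Hypothetical helper: probability of surviving through the end of each
    # interval is the cumulative product of per-interval survival probabilities
    return torch.cumprod(interval_probs, dim=1)


# Usage: 4 individuals, 5 predictors, 10 follow-up intervals
torch.manual_seed(0)
model = MultiLayerSurvivalModel(n_predictors=5, n_intervals=10)
model.eval()  # disable dropout when predicting
with torch.no_grad():
    probs = model(torch.randn(4, 5))
    surv = survival_curve(probs)
print(probs.shape, surv.shape)  # torch.Size([4, 10]) torch.Size([4, 10])
```

For image inputs like those in the paper, one option (again an assumption, not the author's code) would be to replace the hidden linear stack with convolutional and pooling layers followed by `nn.Flatten()`, keeping the same final `nn.Linear` plus sigmoid so there is still one output per interval.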