Using the Cox loss and methods with a custom model
SalvatoreRa opened this issue · comments
Great work,
I found your approach very interesting and I was trying to generalize it to different pytorch architectures
I wanted to test your approach with custom models and other pytorch model. The idea is to basically take a pytorch model (arbitrary architecture) and test the ability to predict survival.
for example, I wanted to test with a simple pytorch model.
Let's say:
- considering a simple pytorch loop with a generic pytorch model
- the idea is to transform it into a model that predicts survival
now, to better explain there is below:
- a simple code used for transforming the task in binary classification (I used the dataset you provided, just to create a code that works)
- the functions from your repository that I think may be useful for the task
What I am trying to understand is, considering this case:
- how to modify a simple architecture for the survival prediction (I am not exactly sure how the last layer should be)
- how to incorporate the loss in the training loop. More broadly, how to modify the loop to use a pytorch model (which can have different layers, different architectures and so on) for the survival task, using the loss you provide
taking this dataset and starting from your example:
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from lassonet import LassoNetCoxRegressor
from lassonet import plot_path
# Directory with the hnscc csv files; the first row of each file is a header
# and is skipped.
res_dir = './survival/'
X, y = (
    np.genfromtxt(res_dir + csv_name, delimiter=",", skip_header=1)
    for csv_name in ("hnscc_x.csv", "hnscc_y.csv")
)
this is a simple version of the approach modelling the survival as a simple binary classification approach:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, random_split, SubsetRandomSampler, ConcatDataset, Dataset
import pandas as pd
import seaborn as sns
# creating a simple MLP
class FCNNC(nn.Module):
    """Small fully connected network with two hidden layers.

    Data flows ``fc1 -> tanh -> fc2 -> sigmoid -> fc3``. The last layer is a
    plain ``nn.Linear``, so ``forward`` returns raw scores with no final
    activation.

    Args:
        input_size: number of input features.
        constraint_size: width of the first hidden layer.
        hidden_size: width of the second hidden layer.
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, input_size, constraint_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, constraint_size)
        self.fc2 = nn.Linear(constraint_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw outputs of the network for the batch ``x``."""
        x = torch.tanh(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return self.fc3(x)
# simple class for the dataset
class DataClassifier(Dataset):
    """In-memory dataset pairing feature rows with integer class labels.

    Args:
        X_train: NumPy array of features; stored as a ``float32`` tensor.
        y_train: NumPy array of labels; stored as an ``int64`` tensor.
    """

    def __init__(self, X_train, y_train):
        self.X = torch.from_numpy(X_train.astype(np.float32))
        # Labels are cast to int64 (torch "long").
        self.y = torch.from_numpy(y_train).long()
        self.len = self.X.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.X[index], self.y[index]

    def __len__(self):
        """Return the number of samples (rows of ``X``)."""
        return self.len
# binary accuracy
def multi_acc(y_pred, y_test):
    """Return the classification accuracy as a rounded percentage.

    Args:
        y_pred: 2-D tensor of per-class scores, one row per sample.
        y_test: 1-D tensor holding the true class index of each sample.

    Returns:
        Scalar tensor with the percentage of rows whose highest-scoring
        class equals the label, rounded by ``torch.round``.
    """
    # Keep the predicted indices in their own name instead of overwriting
    # the ``y_pred`` argument.
    predicted_class = torch.max(y_pred, dim=1).indices
    correct_pred = (predicted_class == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)
    return torch.round(acc * 100)
# transforming in binary classification
# Binary-classification setup: the target is the second column of y.
batch_size = 2048
X_train, X_test, Y_train, Y_test = train_test_split(X, y[:, 1], random_state=0)

# Training batches are shuffled; validation is served as one full-size batch.
traindata = DataClassifier(X_train, Y_train)
trainloader = torch.utils.data.DataLoader(
    traindata, shuffle=True, batch_size=batch_size
)
valdata = DataClassifier(X_test, Y_test)
valloader = torch.utils.data.DataLoader(
    valdata, shuffle=False, batch_size=len(X_test)
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()
model = FCNNC(
    input_size=X.shape[1],
    constraint_size=20,
    hidden_size=20,
    num_classes=2,
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
n_epochs = 1000
%matplotlib inline
# simple training loop to store results and plotting
# Per-epoch history of mean accuracy and mean loss, used for plotting later.
accuracy_stats = {
    'train': [],
    "val": []
}
loss_stats = {
    'train': [],
    "val": []
}

# Move the model to the target device once instead of on every batch.
model.to(device)

for epoch in range(n_epochs):
    # model.eval() in the validation pass below would otherwise stay in
    # effect for every following epoch.
    model.train()
    running_loss = 0.0
    train_epoch_acc = 0
    for inputs, labels in trainloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # clear the gradients accumulated by the previous batch
        optimizer.zero_grad()
        # forward propagation
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        acc = multi_acc(outputs, labels)
        # backward propagation
        loss.backward()
        # optimize
        optimizer.step()
        running_loss += loss.item()
        train_epoch_acc += acc.item()

    # Validation pass: no gradients are recorded.
    with torch.no_grad():
        val_epoch_loss = 0
        val_epoch_acc = 0
        model.eval()
        for X_val_batch, y_val_batch in valloader:
            X_val_batch = X_val_batch.to(device)
            y_val_batch = y_val_batch.to(device)
            y_val_pred = model(X_val_batch)
            val_loss = criterion(y_val_pred, y_val_batch)
            val_acc = multi_acc(y_val_pred, y_val_batch)
            val_epoch_loss += val_loss.item()
            val_epoch_acc += val_acc.item()

    # Store the per-batch averages of this epoch.
    loss_stats['train'].append(running_loss/len(trainloader))
    loss_stats['val'].append(val_epoch_loss/len(valloader))
    accuracy_stats['train'].append(train_epoch_acc/len(trainloader))
    accuracy_stats['val'].append(val_epoch_acc/len(valloader))

    # Report every 50 epochs. The original compared the remainder with
    # ``True`` (i.e. 1), which reported on epochs 1, 51, 101, ...
    if epoch % 50 == 0:
        print(f'Epoch {epoch:03}: | Train Loss: {running_loss/len(trainloader):.5f} | Val Loss: {val_epoch_loss/len(valloader):.5f} | Train Acc: {train_epoch_acc/len(trainloader):.3f}| Val Acc: {val_epoch_acc/len(valloader):.3f}')
# Reshape each history dict into long format: one row per (epoch, split).
train_val_acc_df = (
    pd.DataFrame.from_dict(accuracy_stats)
    .reset_index()
    .melt(id_vars=['index'])
    .rename(columns={"index": "epochs"})
)
train_val_loss_df = (
    pd.DataFrame.from_dict(loss_stats)
    .reset_index()
    .melt(id_vars=['index'])
    .rename(columns={"index": "epochs"})
)
# Plot the dataframes side by side: accuracy on the left, loss on the right.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 7))
sns.lineplot(
    data=train_val_acc_df, x="epochs", y="value", hue="variable", ax=axes[0]
).set_title('Train-Val Accuracy/Epoch')
sns.lineplot(
    data=train_val_loss_df, x="epochs", y="value", hue="variable", ax=axes[1]
).set_title('Train-Val Loss/Epoch')
The idea is to start from this very simple example and turn the model into one that can handle censored data
I was highlighting this code from your repository:
import torch
from sortedcontainers import SortedList
def log_substract(x, y):
    """Return ``log(exp(x) - exp(y))`` computed in log space.

    Uses the identity ``log(e^x - e^y) = x + log1p(-e^(y - x))``, so only the
    difference ``y - x`` is exponentiated.

    The result is finite only where ``y < x``: it is ``-inf`` where
    ``y == x`` and ``nan`` where ``y > x``.
    """
    return x + torch.log1p(-(y - x).exp())
def scatter_logsumexp(input, index, *, dim=-1, output_size=None):
    """Inspired by torch_scatter.logsumexp
    Uses torch.scatter_reduce for performance

    Computes a group-wise ``logsumexp`` of ``input`` along ``dim``: elements
    that share the same value in ``index`` are reduced into the output slot
    with that index.

    Args:
        input: tensor of values to reduce.
        index: integer tensor of the same shape as ``input`` giving the
            destination slot of every element along ``dim``.
        dim: dimension along which to reduce.
        output_size: number of output slots along ``dim``; defaults to
            ``index.max() + 1``.

    Returns:
        Tensor shaped like ``input`` except that ``dim`` has length
        ``output_size``. Slots that receive no element are ``-inf``.
    """
    if output_size is None:
        # One slot per group id in 0..max(index).
        output_size = int(index.max()) + 1
    out_shape = list(input.shape)
    out_shape[dim] = output_size

    # Per-group maximum, subtracted before exponentiating so that exp()
    # cannot overflow. Untouched slots keep the -inf fill value.
    max_value_per_index = input.new_full(out_shape, float("-inf")).scatter_reduce(
        dim, index, input, reduce="amax", include_self=False
    )
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = input - max_per_src_element
    # Per-group sum of the recentred exponentials; untouched slots stay 0.
    sum_per_index = input.new_zeros(out_shape).scatter_reduce(
        dim, index, recentered_scores.exp(), reduce="sum", include_self=False
    )
    return max_value_per_index + sum_per_index.log()
class CoxPHLoss(torch.nn.Module):
    """Loss for CoxPH model.

    Negative log partial likelihood, with tied event times handled by
    either the Breslow or the Efron approximation. The per-event terms are
    averaged (``mean``), not summed.

    Args:
        method: tie-handling approximation, one of ``allowed``.
    """

    allowed = ("breslow", "efron")

    def __init__(self, method):
        super().__init__()
        assert method in self.allowed, f"Method must be one of {self.allowed}"
        self.method = method

    def forward(self, log_h, y):
        """Compute the loss for one batch.

        Args:
            log_h: predicted log-hazards, one per sample (flattened here).
            y: 2-D tensor with one row per sample and two columns:
                duration and event indicator (non-zero = event observed).

        Returns:
            Scalar tensor ``log_den - log_num``. When the batch contains no
            event, a zero that is still attached to the graph.
        """
        log_h = log_h.flatten()

        durations, events = y.T

        # sort input by decreasing duration, so a cumulative reduction over
        # the sorted samples covers everyone with a duration at least as long
        durations, idx = durations.sort(descending=True)
        log_h = log_h[idx]
        events = events[idx]

        event_ind = events.nonzero().flatten()

        if event_ind.numel() == 0:
            # A batch with only censored samples has no likelihood terms.
            # The means below would be taken over empty tensors and return
            # NaN, which would then propagate into the gradients.
            return log_h.sum() * 0.0

        # numerator
        log_num = log_h[event_ind].mean()

        # logcumsumexp of events
        event_lcse = torch.logcumsumexp(log_h, dim=0)[event_ind]

        # number of events for each unique risk set
        _, tie_inverses, tie_count = torch.unique_consecutive(
            durations[event_ind], return_counts=True, return_inverse=True
        )

        # position of last event (lowest duration) of each unique risk set
        tie_pos = tie_count.cumsum(dim=0) - 1

        # logcumsumexp by tie for each event
        event_tie_lcse = event_lcse[tie_pos][tie_inverses]

        if self.method == "breslow":
            log_den = event_tie_lcse.mean()

        elif self.method == "efron":
            # based on https://bydmitry.github.io/efron-tensorflow.html

            # logsumexp of ties, duplicated within tie set
            tie_lse = scatter_logsumexp(log_h[event_ind], tie_inverses, dim=0)[
                tie_inverses
            ]
            # multiply (add in log space) with corrective factor
            # aux is 1 everywhere except at the first event of each later tie
            # set, so its cumulative sum restarts the count inside every set
            aux = torch.ones_like(tie_inverses)
            aux[tie_pos[:-1] + 1] -= tie_count[:-1]
            event_id_in_tie = torch.cumsum(aux, dim=0) - 1
            discounted_tie_lse = (
                tie_lse
                + torch.log(event_id_in_tie)
                - torch.log(tie_count[tie_inverses])
            )
            # denominator
            log_den = log_substract(event_tie_lcse, discounted_tie_lse).mean()

        # loss is negative log likelihood
        return log_den - log_num
def concordance_index(risk, time, event):
    """
    O(n log n) implementation of https://square.github.io/pysurvival/metrics/c_index.html

    Samples are visited in increasing ``time`` order. Each visited sample is
    compared with every earlier sample that had an observed event; the pair
    counts as concordant when the earlier sample's risk is strictly greater.

    Args:
        risk: predicted risk score per sample.
        time: observed duration per sample.
        event: truthy where the event was observed, falsy where censored.

    Returns:
        Number of concordant pairs divided by the number of comparable pairs.

    Raises:
        ZeroDivisionError: if there is no comparable pair at all.
    """
    assert len(risk) == len(time) == len(event)
    n = len(risk)
    order = sorted(range(n), key=time.__getitem__)
    # risks of the already visited samples whose event was observed
    past = SortedList()
    num = 0
    den = 0
    for i in order:
        # earlier events with a risk strictly greater than the current one
        num += len(past) - past.bisect_right(risk[i])
        den += len(past)
        if event[i]:
            past.add(risk[i])
    if den == 0:
        # Same exception type as the bare division, with an explanation.
        raise ZeroDivisionError(
            "concordance index is undefined: no comparable pairs"
        )
    return num / den
Thank you very much
Salvatore
Any model can use the CoxPHLoss. The loss will be:
# `method` is a required argument of CoxPHLoss: "breslow" or "efron".
criterion = CoxPHLoss(method="breslow")
loss = criterion(model(X_train[batch]), y_train[batch])
To transform the data, please look at this example: https://github.com/lasso-net/lassonet/blob/master/examples/cox_experiments.py
Thank you for your reply,
If I have understood correctly, for a dataset one should do:
# Read both csv files, skipping the single header row of each.
X, y = (
    np.genfromtxt(csv_path, delimiter=",", skip_header=1)
    for csv_path in (path_x, path_y)
)
# Fit the scaler on X, then transform X with it.
scaler = preprocessing.StandardScaler()
scaler.fit(X)
X = scaler.transform(X)
For instance, this should work if the dataset is saved in the same format as the provided data: hnscc
Once done that, you can split in training and test set:
'''
# Hold out 20% of the rows, stratifying on the second column of y.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.20,
    stratify=y[:, 1],
    random_state=random_state,
)
'''
Should I do anything for the data loader? and for the NN architecture (like which should be the last layer)?
For y you should have both a duration and a boolean column for events. For the data loader, just use mini batches. The last layer should just output a real number, so a Linear layer is good.
Thank you,
I have done so:
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from lassonet import LassoNetCoxRegressor
from lassonet import plot_path
# Directory with the hnscc csv files; each file starts with one header row.
res_dir = './survival/'
X, y = [
    np.genfromtxt(res_dir + file_name, delimiter=",", skip_header=1)
    for file_name in ("hnscc_x.csv", "hnscc_y.csv")
]
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, random_split, SubsetRandomSampler, ConcatDataset, Dataset
import pandas as pd
import seaborn as sns
class FCNNC(nn.Module):
    """Small fully connected network with two hidden layers.

    Data flows ``fc1 -> tanh -> fc2 -> relu -> fc3``. The last layer is a
    plain ``nn.Linear`` with no activation; with ``num_classes=1`` it emits
    one real number per sample.

    Args:
        input_size: number of input features.
        constraint_size: width of the first hidden layer.
        hidden_size: width of the second hidden layer.
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, input_size, constraint_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, constraint_size)
        self.fc2 = nn.Linear(constraint_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw outputs of the network for the batch ``x``."""
        x = torch.tanh(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
class DataClassifier(Dataset):
    """In-memory dataset pairing feature rows with float targets.

    Args:
        X_train: NumPy array of features; stored as a ``float32`` tensor.
        y_train: NumPy array of targets; stored as a ``float32`` tensor.
            NOTE(review): in this thread ``y`` is expected to hold one
            (duration, event) row per sample -- confirm against the data.
    """

    def __init__(self, X_train, y_train):
        self.X = torch.from_numpy(X_train.astype(np.float32))
        self.y = torch.from_numpy(y_train.astype(np.float32))
        self.len = self.X.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        return self.X[index], self.y[index]

    def __len__(self):
        """Return the number of samples (rows of ``X``)."""
        return self.len
# Survival setup: y is passed whole (both columns), not just one column.
batch_size = 200
X_train, X_test, Y_train, Y_test = train_test_split(X, y, random_state=0)

# Training batches are shuffled; validation is served as one full-size batch.
traindata = DataClassifier(X_train, Y_train)
trainloader = torch.utils.data.DataLoader(
    traindata, shuffle=True, batch_size=batch_size
)
valdata = DataClassifier(X_test, Y_test)
valloader = torch.utils.data.DataLoader(
    valdata, shuffle=False, batch_size=len(X_test)
)
n_epochs = 1000
%matplotlib inline
# Use the first GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = CoxPHLoss(method="breslow")
# A single output unit: the network emits one real number per sample.
model = FCNNC(
    input_size=X.shape[1],
    constraint_size=20,
    hidden_size=20,
    num_classes=1,
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# Per-epoch history of the mean training and validation loss.
loss_stats = {
    'train': [],
    "val": []
}

# Move the model to the target device once instead of on every batch.
model.to(device)

for epoch in range(n_epochs):
    # model.eval() in the validation pass below would otherwise stay in
    # effect for every following epoch.
    model.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # clear the gradients accumulated by the previous batch
        optimizer.zero_grad()
        # forward propagation: run the model once and reuse the result
        # (the original ran a second, redundant forward pass for the loss)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # backward propagation
        loss.backward()
        # optimize
        optimizer.step()
        running_loss += loss.item()

    # Validation pass: no gradients are recorded.
    with torch.no_grad():
        val_epoch_loss = 0
        model.eval()
        for X_val_batch, y_val_batch in valloader:
            X_val_batch = X_val_batch.to(device)
            y_val_batch = y_val_batch.to(device)
            y_val_pred = model(X_val_batch)
            val_loss = criterion(y_val_pred, y_val_batch)
            val_epoch_loss += val_loss.item()

    # Store the per-batch averages of this epoch.
    loss_stats['train'].append(running_loss/len(trainloader))
    loss_stats['val'].append(val_epoch_loss/len(valloader))

    # Report every 50 epochs. The original compared the remainder with
    # ``True`` (i.e. 1), which reported on epochs 1, 51, 101, ...
    if epoch % 50 == 0:
        print(f'Epoch {epoch:03}: | Train Loss: {running_loss/len(trainloader):.5f} | Val Loss: {val_epoch_loss/len(valloader):.5f}')
# Reshape the loss history into long format: one row per (epoch, split).
train_val_loss_df = (
    pd.DataFrame.from_dict(loss_stats)
    .reset_index()
    .melt(id_vars=['index'])
    .rename(columns={"index": "epochs"})
)
# Plot the dataframes
sns.lineplot(
    data=train_val_loss_df,
    x="epochs",
    y="value",
    hue="variable",
).set_title('Train-Val Loss/Epoch')
It worked and the loss is decreasing. As a last question: how do you evaluate the model? How should the concordance index (CI) be used? Do you have any suggestions for evaluation?
This is the most basic implementation of a neural network with PyTorch, but I think it could work with any custom network (and it uses the data you provide). Would you like me to organize the script as a tutorial? Could that be useful?