js-ish / DOoC

Digital Organoid On Chips

DOoC

Train

import random
import torch
from torch import nn
import torch.optim as optim

from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig

from dooc import models, datasets

Regression

# Regression datasets
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
    )
ds = datasets.MutSmiReg(smi_tokenizer=tokenizer)
smiles = ["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br"]
mutations = [[random.choice([0, 1]) for _ in range(3008)]] * 2
# each mutation vector is a 0/1 encoding of the genome's mutation status
values = [0.85, 0.78]
mut_x, smi_tgt, out = ds(mutations, smiles, values)

# Regression train
model = models.MutSmiReg()
model.load_pretrained_ckpt(
    mut_ckpt='path/to/drugcell.pt',
    smi_ckpt='path/to/moltx.ckpt'
    )
mse_loss = nn.MSELoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-04,
    foreach=False
    )
optimizer.zero_grad()
pred = model(mut_x=mut_x, smi_tgt=smi_tgt)
loss = mse_loss(pred, out)
loss.backward()
optimizer.step()

torch.save(model.state_dict(), '/path/to/mutsmireg.ckpt')
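
The regression example above performs a single optimization step before saving. In a real run you would iterate over your own batches for several epochs before the torch.save call; a minimal sketch, where train_batches is a hypothetical stand-in that simply reuses the toy batch built above:

# Minimal epoch/batch loop sketch; `train_batches` is a hypothetical stand-in
# that here just repeats the toy example batch from above.
train_batches = [(mutations, smiles, values)]
model.train()
for epoch in range(10):
    for mut_b, smi_b, val_b in train_batches:
        mut_x, smi_tgt, out = ds(mut_b, smi_b, val_b)
        optimizer.zero_grad()
        pred = model(mut_x=mut_x, smi_tgt=smi_tgt)
        loss = mse_loss(pred, out)
        loss.backward()
        optimizer.step()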

Pairwise

# Pairwise datasets
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
    )
ds = datasets.MutSmisPairwiseRank(smi_tokenizer=tokenizer)
smiles = [["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br"],
          ["CC[N+](C)(C)Cc1ccccc1Br", "CN(Cc1oc2ccccc2c1C)C(=O)\C=C\c1cnc2NC(=O)CCc2c1"]]
mutations = [[random.choice([0, 1]) for _ in range(3008)]] * 2
# each mutation vector is a 0/1 encoding of the genome's mutation status
values = [[0.85, 0.78]] * 2
mut_x, smi_tgt, out = ds(mutations, smiles, values)

# Pairwise train
model = models.MutSmisRank()
model.load_pretrained_ckpt(
    mut_ckpt='path/to/drugcell.pt',
    smi_ckpt='path/to/moltx.ckpt'
    )

# Pairwise loss
loss_func = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-04,
    foreach=False
    )

optimizer.zero_grad()
pred = model(mut_x=mut_x, smi_tgt=smi_tgt)
loss = loss_func(pred[:,0] - pred[:,1], out)
loss.backward()
optimizer.step()
torch.save(model.state_dict(), '/path/to/mutsmipairwise.ckpt')
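
Applying BCEWithLogitsLoss to the score difference pred[:, 0] - pred[:, 1] above is the standard RankNet-style pairwise objective: the binary target says which of the two drugs should score higher. A minimal plain-PyTorch illustration with dummy scores (independent of dooc; the exact label encoding produced by MutSmisPairwiseRank may differ):

import torch

# Dummy scores for two samples; column 0 is drug A, column 1 is drug B.
scores = torch.tensor([[1.2, 0.4],
                       [0.1, 0.9]])
# Label 1.0 means drug A should outrank drug B, 0.0 means the opposite.
labels = torch.tensor([1.0, 0.0])
# -log sigmoid(s_A - s_B) when the label is 1, -log(1 - sigmoid(s_A - s_B)) when it is 0.
pair_loss = torch.nn.BCEWithLogitsLoss()(scores[:, 0] - scores[:, 1], labels)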

Listwise

# Listwise datasets
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
    )
ds = datasets.MutSmisListwiseRank(smi_tokenizer=tokenizer)
smiles = [["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br", "CN(Cc1oc2ccccc2c1C)C(=O)\C=C\c1cnc2NC(=O)CCc2c1"]] * 2
mutations = [[random.choice([0, 1]) for _ in range(3008)]] * 2
# each mutation vector is a 0/1 encoding of the genome's mutation status
values = [[0.85, 0.78, 0.79]] * 2
mut_x, smi_tgt, out = ds(mutations, smiles, values)
mut_x, smi_tgt, out = mut_x.squeeze(0), smi_tgt.squeeze(0), out.squeeze(0)

# Listwise train
model = models.MutSmisRank()
model.load_pretrained_ckpt(
    mut_ckpt='path/to/drugcell.pt',
    smi_ckpt='path/to/moltx.ckpt'
    )

# Listwise loss: ListNetLoss comes from dooc's listwise-loss utilities,
# referenced here as dooc_list_loss (adjust the import to your dooc version)
loss_func = dooc_list_loss.ListNetLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-04,
    foreach=False
    )

optimizer.zero_grad()
pred = model(mut_x=mut_x, smi_tgt=smi_tgt)
loss = loss_func(pred, out)
loss.backward()
optimizer.step()

torch.save(model.state_dict(), '/path/to/mutsmilistwise.ckpt')
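
ListNetLoss above is taken from dooc's listwise-loss utilities. For reference, a ListNet-style objective is just the cross-entropy between the softmax of the target scores and the softmax of the predicted scores; the sketch below is an illustrative plain-PyTorch version, not dooc's implementation:

import torch
import torch.nn.functional as F

def listnet_loss(pred_scores: torch.Tensor, true_scores: torch.Tensor) -> torch.Tensor:
    # ListNet top-1 approximation: cross-entropy between the two score distributions.
    true_dist = F.softmax(true_scores, dim=-1)
    return -(true_dist * F.log_softmax(pred_scores, dim=-1)).sum(dim=-1).mean()

# Dummy predicted and target scores for a list of three drugs.
print(listnet_loss(torch.tensor([[0.2, 0.7, 0.1]]),
                   torch.tensor([[0.85, 0.78, 0.79]])))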

Inference

import random
from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig
from dooc import pipelines, models

# Regression
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
    )
model = models.MutSmiReg()
model.load_ckpt('/path/to/mutsmireg.ckpt')
pipeline = pipelines.MutSmiReg(
    smi_tokenizer=tokenizer, model=model
    )
mutations = [random.choice([0, 1]) for _ in range(3008)]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mut=mutations, smi=smiles) # e.g. 0.85
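
A typical use of the regression pipeline is to score a small candidate set against one mutation profile and sort by the predicted value. A usage sketch reusing the pipeline and mutations defined above (whether higher or lower is better depends on how the response labels were defined):

# Score each candidate against the same mutation profile and sort, highest prediction first.
candidates = ["c1ccccc1", "CC[N+](C)(C)Cc1ccccc1Br"]
scored = sorted(((pipeline(mut=mutations, smi=s), s) for s in candidates), reverse=True)
for value, s in scored:
    print(value, s)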

# Rank
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
    )
model = models.MutSmisRank()
model.load_ckpt('/path/to/mutsmirank.ckpt')
pipeline = pipelines.MutSmisRank(smi_tokenizer=tokenizer, model=model)
mutations = [random.choice([0, 1]) for _ in range(3008)]
smiles = ["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br", "CN(Cc1oc2ccccc2c1C)C(=O)\C=C\c1cnc2NC(=O)CCc2c1"]
predict = pipeline(mut=mutations, smis=smiles) # e.g. ["CN(Cc1oc2ccccc2c1C)C(=O)\C=C\c1cnc2NC(=O)CCc2c1", "CC[N+](C)(C)Cc1ccccc1Br", "c1cccc1c"]
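
The rank pipeline returns the candidate SMILES reordered for the given mutation profile, so selecting a lead compound is plain list indexing; best-first ordering is assumed below and should be verified against your dooc version:

# `predict` is the reordered candidate list from the call above.
top_candidate = predict[0]  # assumed best-first ordering
print(top_candidate)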

About

Digital Organoid On Chips

License: Apache License 2.0

