Copy the files under the "template" directory.
get_dataset
- returns a dictionary of Chainer datasets.
- Dictionary keys must include "train" for the training dataset.
- Dictionary keys must include "test" for the test dataset, if you need one.
- Dictionary keys may include "validation" for validation during training.
- example:
```python
import chainer

def get_dataset():
    train, test = chainer.datasets.get_mnist()
    # Hold out 5000 randomly chosen training examples for validation.
    validation, train = chainer.datasets.split_dataset_random(train, 5000)
    return {
        'train': train,
        'validation': validation,
        'test': test,
    }
```
calculate_metrics
- returns the loss (and optionally other metrics) for each optimizer.
- examples:

returning only the loss:
```python
import chainer.functions as F

def calculate_metrics(net, batch):
    x, t = batch
    y = net(x)
    loss = F.softmax_cross_entropy(y, t)
    return loss
```

with other metrics:
```python
import chainer.functions as F

def calculate_metrics(net, batch):
    x, t = batch
    y = net(x)
    loss = F.softmax_cross_entropy(y, t)
    accuracy = F.accuracy(y, t)
    return {
        'loss': loss,
        'accuracy': accuracy,
    }
```

with two optimizers:
```python
def calculate_metrics(nets, batch):
    gen = nets['gen']
    dis = nets['dis']
    # ... calculate loss
    loss_gen = ...
    loss_dis = ...
    return {
        'gen': loss_gen,
        'dis': loss_dis,
    }
    # or
    # return {
    #     'gen': {},
    #     'dis': loss_dis,
    # }
```
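The return shapes above are a bare loss, a dictionary of metrics, or a dictionary keyed by optimizer name. A hypothetical helper can turn any of these into one metrics dictionary per optimizer. This is a sketch under stated assumptions, not the template's own code. The helper name `normalize_metrics`, the explicit list of optimizer names, and storing a bare value under the key `'loss'` are all assumptions.

```python
def normalize_metrics(result, optimizer_names):
    """Hypothetical helper: map a calculate_metrics result to
    {optimizer_name: {metric_name: value}}."""

    def as_dict(value):
        # A bare value is assumed to be the loss; a dict is taken as metrics.
        return dict(value) if isinstance(value, dict) else {"loss": value}

    if len(optimizer_names) == 1:
        # Single optimizer: the whole result belongs to it.
        return {optimizer_names[0]: as_dict(result)}
    # Several optimizers: the result is keyed by optimizer name.
    return {name: as_dict(result[name]) for name in optimizer_names}
```

Passing the optimizer names explicitly avoids guessing whether `{'loss': ..., 'accuracy': ...}` is one optimizer's metrics or two optimizers' losses.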
make_eval_func
- makes an evaluation function for the Chainer Evaluator extension.
predict
- returns the prediction of your neural network.
Implement your neural network as a class that extends chainer.Chain.
Usually you don't have to modify the training script, except to add Chainer training extensions if you need them.
Usually you don't have to modify the prediction script.
Implement your evaluation script.
Configuration independent of each training/prediction run.
- train (object): training configuration.
  - dump_graph (str): Name of the root for the dump_graph extension.
  - plot_report (object):
    - key: File name of the figure file, without extension.
    - value (array of str): y_keys for the PlotReport extension.
  - print_report (array of str): entries for the PrintReport extension.
  - max_value_trigger (string or null): The best model is saved when the value associated with this key becomes maximum (e.g. "validation/main/accuracy").
  - min_value_trigger (string or null): The best model is saved when the value associated with this key becomes minimum (e.g. "validation/main/loss").

You can use only one of max_value_trigger and min_value_trigger.
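The fields above can be illustrated with a hypothetical shared configuration. Two things are assumptions: that the file is JSON (the field types, such as null, suggest this) and all the values shown. The helper `best_model_trigger` is also hypothetical. It enforces the "only one of max_value_trigger and min_value_trigger" rule. In Chainer, the selected key would typically feed `chainer.training.triggers.MaxValueTrigger` or `MinValueTrigger`, but how the template wires this up is not shown here.

```python
import json

# Hypothetical shared configuration; the values are illustrative only.
SHARED_CONFIG = """
{
  "train": {
    "dump_graph": "main/loss",
    "plot_report": {
      "loss": ["main/loss", "validation/main/loss"],
      "accuracy": ["main/accuracy", "validation/main/accuracy"]
    },
    "print_report": ["epoch", "main/loss", "validation/main/loss",
                     "main/accuracy", "validation/main/accuracy"],
    "max_value_trigger": "validation/main/accuracy",
    "min_value_trigger": null
  }
}
"""


def load_train_config(text):
    """Parse the JSON text and return its "train" section."""
    return json.loads(text)["train"]


def best_model_trigger(train):
    """Return ("max", key), ("min", key) or None; reject setting both keys."""
    max_key = train.get("max_value_trigger")
    min_key = train.get("min_value_trigger")
    if max_key is not None and min_key is not None:
        raise ValueError("use only one of max_value_trigger and min_value_trigger")
    if max_key is not None:
        return ("max", max_key)
    if min_key is not None:
        return ("min", min_key)
    return None
```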
Configuration specific to each training/prediction run.
- batch_size (int): Mini-batch size for training/prediction.
- epoch (int): The number of epochs for training.
- gpu (int): Default GPU device index (a negative value indicates CPU). This can be changed with a command line option.
- dataset (object):
  - parameter (object, array, single value or null): Parameter for get_dataset. null indicates no parameter.
- network (object): Network information.
  - key: network identifier.
  - value (object):
    - class (str): Network class.
    - parameter (object, array, single value or null): Parameter for the network constructor.
    - optimizer (object): Network optimizer information.
      - class (str): Optimizer class.
      - parameter (object, array, single value or null): Parameter for the optimizer constructor.
      - hook (array of objects):
        - class (str): Hook class.
        - parameter (object, array, single value or null): Parameter for the hook constructor.
- output_dir (str): Output directory path for training.
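The "class"/"parameter" fields above could be turned into objects roughly as follows. This is a sketch under several assumptions, none of them confirmed by the template:

- "class" is assumed to be a dotted import path.
- The parameter mapping is assumed to be: an object becomes keyword arguments, an array becomes positional arguments, a single value becomes one argument, and null means no arguments.
- The helper names `call_with_parameter`, `instantiate`, and `build_networks` are hypothetical.

The Chainer classes in the example entry are real but chosen only for illustration.

```python
import importlib


def call_with_parameter(func, parameter):
    """Call func using the assumed meaning of a "parameter" field."""
    if parameter is None:              # null: no parameter
        return func()
    if isinstance(parameter, dict):    # object: keyword arguments
        return func(**parameter)
    if isinstance(parameter, list):    # array: positional arguments
        return func(*parameter)
    return func(parameter)             # single value: one argument


def instantiate(spec):
    """Build an object from a {"class": "pkg.module.Name", "parameter": ...} entry."""
    module_name, _, class_name = spec["class"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return call_with_parameter(cls, spec.get("parameter"))


def build_networks(network_config):
    """Create each network, its optimizer, and the optimizer's hooks."""
    nets, optimizers = {}, {}
    for name, spec in network_config.items():
        net = instantiate(spec)
        optimizer = instantiate(spec["optimizer"])
        optimizer.setup(net)  # Chainer requires setup() before add_hook()
        for hook_spec in spec["optimizer"].get("hook") or []:
            optimizer.add_hook(instantiate(hook_spec))
        nets[name] = net
        optimizers[name] = optimizer
    return nets, optimizers


# Illustrative "network" entry keyed by a network identifier.
NETWORK_CONFIG = {
    "main": {
        "class": "chainer.links.Linear",
        "parameter": [784, 10],
        "optimizer": {
            "class": "chainer.optimizers.Adam",
            "parameter": {"alpha": 0.001},
            "hook": [
                {"class": "chainer.optimizer_hooks.WeightDecay", "parameter": 0.0001}
            ],
        },
    }
}
```

Under the same assumption, the dataset would be loaded with `call_with_parameter(get_dataset, config["dataset"]["parameter"])`, where `config` is the parsed configuration. The networks would come from `build_networks(config["network"])`.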