
gunpowder

A data-loading, training, and processing DAG for greentea.

Based on PyGreentea by William Grisaitis, Fabian Tschopp, and Srini Turaga.

First steps

In gunpowder, you assemble a training pipeline as a directed acyclic graph (DAG) of batch providers. Everything starts with a "source", a batch provider with no inputs:

source = Hdf5Source(
        'example.hdf',
        raw_dataset='volumes/raw',
        gt_dataset='volumes/labels/neuron_ids',
        gt_mask_dataset='volumes/labels/mask')

Every batch provider can be asked for a batch via a request (specifying, e.g., the shape, the offset, and which kinds of volumes to provide). Starting from a source, you can chain batch providers to build a DAG. When a non-source batch provider is asked for a batch, it passes the request on upstream (i.e., towards a source) to receive a batch, possibly modifies it, and passes the batch on downstream.
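The following self-contained toy illustrates these mechanics. It is a sketch of the idea only, not gunpowder's actual implementation; all names in it (ToyProvider, ToySource, ToyFilter) are made up for the example:

class ToyProvider:

    def __add__(self, downstream):
        # chaining with '+' wires this provider in as the upstream
        # of the next one
        downstream.upstream = self
        return downstream

class ToySource(ToyProvider):

    def request_batch(self, request):
        # a source answers a request directly, it has no upstream
        return {'shape': request}

class ToyFilter(ToyProvider):

    def request_batch(self, request):
        # pass the (possibly modified) request on upstream...
        batch = self.upstream.request_batch(request)
        # ...then modify the received batch and pass it on downstream
        batch['filtered'] = True
        return batch

pipeline = ToySource() + ToyFilter()
print(pipeline.request_batch((84, 268, 268)))
# {'shape': (84, 268, 268), 'filtered': True}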

As an example, this scheme allows the modelling of on-the-fly data augmentation as a batch provider:

augment = ElasticAugment(
        control_point_spacing=[4,40,40],
        jitter_sigma=[0,2,2],
        rotation_interval=[0,math.pi/2.0])

When augment is asked for a batch via a request, the request is automatically changed to request an upstream batch large enough to perform the elastic augmentation seamlessly.
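To get an intuition for what "large enough" means, consider this hypothetical back-of-the-envelope calculation. The numbers are illustrative only; the actual context gunpowder derives depends on control_point_spacing, jitter_sigma, and the rotation interval:

# illustrative numbers only, not what ElasticAugment actually computes
requested_shape = (84, 268, 268)
context = (4, 12, 12)  # hypothetical margin needed on each side

upstream_shape = tuple(s + 2*c for s, c in zip(requested_shape, context))
# upstream_shape == (92, 292, 292): augment requests this larger batch,
# augments it, and crops the result back to the original (84, 268, 268)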

Another example is the random selection of locations inside a source:

random = RandomLocation()

Training itself is modelled as a batch provider. It takes a batch, performs one training iteration, and adds the current prediction and loss to the batch:

solver_parameters = SolverParameters()
# set solver parameters (network, learning rate, optimizer, etc.)
train = Train(solver_parameters, use_gpu=0)

Putting it together, a very simple pipeline for training 1000 iterations would be

pipeline = source + random + augment + train

request = BatchRequest()
request.add_volume_request(VolumeType.RAW, (84,268,268))
request.add_volume_request(VolumeType.GT_LABELS, (56,56,56))
request.add_volume_request(VolumeType.GT_MASK, (56,56,56))

with build(pipeline) as p:
  for i in range(1000):
    p.request_batch(request)

Note that we use a BatchRequest object to communicate our requirements for a batch upstream. In the example, we are interested in batches of certain sizes (fitting the network we want to train) with raw data, ground-truth labels, and a mask.

For a full working example (with many more batch providers), see the example directory.

Writing Custom Batch Providers

The simplest batch provider is a BatchFilter, which has only one upstream provider. To create a new one, subclass it and override prepare and/or process:

class ExampleFilter(BatchFilter):

  def prepare(self, request):
    # modify the request that is passed on upstream,
    # e.g., enlarge the requested sizes
    pass

  def process(self, batch, request):
    # modify the batch received from upstream so that it
    # meets the original downstream request
    pass

prepare and process will be called in an alternating fashion. prepare is called first, when a BatchRequest is passed upstream through the filter. Your filter is free to change the request in any way it needs to, for example, by increasing the requested sizes. After prepare, process will be called with a batch going downstream, which is the upstream's response to the request you modified in prepare. In process, your filter should make all necessary changes to the batch and ensure that it meets the original downstream request (which is passed to process as the request argument for convenience).
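As an illustration, here is a sketch of a filter that requests extra context upstream and crops it away again before passing the batch on. The accessors used here (request.volumes, batch.volumes, and the grow/crop helpers) are assumptions made for this example and may not match the exact gunpowder API:

class GrowAndCrop(BatchFilter):

  # hypothetical context needed on each side, in voxels
  context = (4, 12, 12)

  def prepare(self, request):
    # grow the requested raw ROI so upstream provides enough context
    # (assumed accessor: request.volumes[VolumeType.RAW] is a ROI)
    roi = request.volumes[VolumeType.RAW]
    request.volumes[VolumeType.RAW] = roi.grow(self.context, self.context)

  def process(self, batch, request):
    # crop the raw volume back to the originally requested ROI, so
    # downstream receives exactly what it asked for
    raw = batch.volumes[VolumeType.RAW]
    batch.volumes[VolumeType.RAW] = raw.crop(request.volumes[VolumeType.RAW])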

For an example of a batch filter changing both the spec going upstream and the batch going downstream, see ElasticAugment.
