mengjin001 / dask-tensorflow

Dask-Tensorflow

Start TensorFlow clusters from Dask

Example

Given a Dask cluster

from dask.distributed import Client
client = Client('scheduler-address:8786')

Start a TensorFlow cluster, naming each job and how many Dask workers it should use

from dask_tensorflow import start_tensorflow
tf_spec, dask_spec = start_tensorflow(client, ps=2, worker=4)

>>> tf_spec
{'worker': ['192.168.1.100:2222', '192.168.1.101:2222',
            '192.168.1.102:2222', '192.168.1.103:2222'],
 'ps': ['192.168.1.104:2222', '192.168.1.105:2222']}

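This starts a tensorflow.train.Server on each participating Dask worker and also creates a queue on each of those workers for transferring data. Both are available directly on the worker objects as the tensorflow_server and tensorflow_queue attributes. The second return value, dask_spec, maps the same job names to the addresses of the Dask workers that host each server.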
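Every server in a TensorFlow cluster is started with the full cluster spec plus its own job name and task index (the job_name and task_index arguments of tf.train.Server). The following is a minimal stdlib sketch, assuming the tf_spec layout shown above, of how each address maps to its role. The helper name task_assignments is hypothetical and is not part of dask_tensorflow.

```python
from typing import Dict, List, Tuple

def task_assignments(tf_spec: Dict[str, List[str]]) -> Dict[str, Tuple[str, int]]:
    """Map each address to the (job_name, task_index) its server would use.

    Hypothetical helper for illustration: the task index is simply the
    address's position within its job's list.
    """
    roles = {}
    for job_name, addresses in tf_spec.items():
        for task_index, address in enumerate(addresses):
            roles[address] = (job_name, task_index)
    return roles

tf_spec = {'worker': ['192.168.1.100:2222', '192.168.1.101:2222',
                      '192.168.1.102:2222', '192.168.1.103:2222'],
           'ps': ['192.168.1.104:2222', '192.168.1.105:2222']}

roles = task_assignments(tf_spec)
print(roles['192.168.1.104:2222'])  # ('ps', 0)
```

With TensorFlow 1.x, each entry would correspond roughly to tf.train.Server(tf.train.ClusterSpec(tf_spec), job_name=job_name, task_index=task_index) running on the machine at that address.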

More Complex Workflow

Typically we then submit long-running Dask tasks that retrieve these servers and take part in general TensorFlow computations.

from dask.distributed import get_worker, worker_client

def ps_function():
    # worker_client() secedes this long-running task from the worker's thread pool
    with worker_client():
        tf_server = get_worker().tensorflow_server
        tf_server.join()  # block here, serving parameters to the TensorFlow workers

ps_tasks = [client.submit(ps_function, workers=worker, pure=False)
            for worker in dask_spec['ps']]

def worker_function():
    with worker_client():
        tf_server = get_worker().tensorflow_server

        # ... use tensorflow as desired ...

worker_tasks = [client.submit(worker_function, workers=worker, pure=False)
                for worker in dask_spec['worker']]
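The pure=False arguments matter here. By default client.submit treats a function as pure and derives the task key from a hash of the function and its arguments, so submitting identical inputs again reuses the existing task instead of running a new one. These functions take no arguments, and the workers= restriction does not enter the key, so without pure=False all submissions would collapse into a single task. Below is a stdlib sketch of that keying behaviour; make_key is a hypothetical, much simplified stand-in for Dask's own hashing, and the worker addresses are made up.

```python
import hashlib
import uuid

def make_key(func, *args, pure=True):
    """Hypothetical, simplified stand-in for how a task key is chosen.

    Pure tasks: the key is derived deterministically from the function name and args.
    Impure tasks: a fresh random suffix, so every call is a distinct task.
    """
    if pure:
        digest = hashlib.sha256(repr((func.__name__, args)).encode()).hexdigest()
        return f"{func.__name__}-{digest}"
    return f"{func.__name__}-{uuid.uuid4()}"

def worker_function():
    pass

workers = ['tcp://10.0.0.1:4000', 'tcp://10.0.0.2:4000']  # hypothetical addresses

# One submission per worker; the target worker is not part of the key.
pure_keys = {make_key(worker_function) for _ in workers}
impure_keys = {make_key(worker_function, pure=False) for _ in workers}
print(len(pure_keys), len(impure_keys))  # 1 2
```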

One simple and flexible approach is to have the worker functions block on their queues while we feed those queues data from Dask arrays, dataframes, and other collections.

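from dask.distributed import get_worker, worker_client

def worker_function():
    with worker_client():
        worker = get_worker()
        tf_server = worker.tensorflow_server
        queue = worker.tensorflow_queue

        while not stopping_condition():  # user-defined: decide when training ends
            batch = queue.get()  # blocks until a batch arrives
            # train with batch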
Then dump blocks of data, such as NumPy arrays or pandas DataFrames, onto these queues.

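from dask.distributed import get_worker

def dump_batch(batch):
    # runs on a Dask worker: hand this batch to that worker's TensorFlow queue
    get_worker().tensorflow_queue.put(batch)


import dask.dataframe as dd
df = dd.read_csv('hdfs:///path/to/*.csv')
# clean up dataframe as necessary
partitions = df.to_delayed()  # delayed pandas dataframes
futures = client.compute(partitions)  # one future per pandas partition
# run only on the TensorFlow worker processes; pure=False because dump_batch has side effects
dumped = client.map(dump_batch, futures, workers=dask_spec['worker'], pure=False)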
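With several TensorFlow workers, each consumer loop needs its own end-of-data signal. The stdlib sketch below assumes round-robin placement of partitions (in the real workflow the Dask scheduler decides placement) and a sentinel-based stopping condition. The names distribute, drain and STOP are hypothetical, and the dict of queues stands in for the tensorflow_queue attributes on each worker.

```python
import queue
from itertools import cycle

STOP = object()  # hypothetical end-of-data sentinel

def distribute(partitions, worker_queues):
    """Put each partition on a worker queue in turn, then one STOP per queue."""
    for q, part in zip(cycle(worker_queues.values()), partitions):
        q.put(part)
    for q in worker_queues.values():
        q.put(STOP)  # every consumer loop gets its own end-of-data signal

def drain(q):
    """Collect items from a queue up to (not including) the sentinel."""
    items = []
    while (item := q.get()) is not STOP:
        items.append(item)
    return items

worker_queues = {name: queue.Queue() for name in ('worker-0', 'worker-1')}
partitions = [[0, 1], [2, 3], [4, 5]]  # e.g. three small "dataframe partitions"
distribute(partitions, worker_queues)
drained = {name: drain(q) for name, q in worker_queues.items()}
print(drained)  # {'worker-0': [[0, 1], [4, 5]], 'worker-1': [[2, 3]]}
```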

License

BSD 3-Clause "New" or "Revised" License

