mbalassi / tensorflowonsparkdemo

TensorFlow on Spark demo

This demo shows how to use TensorFlowOnSpark to run distributed TensorFlow training on a CDH cluster. It is based on https://github.com/yahoo/TensorFlowOnSpark/wiki/GetStarted_YARN, with the following differences:

  • A Conda env is used, rather than building Python and TensorFlow from scratch.
  • RDMA is not used, so vanilla TensorFlow is all that is needed.

Further information about using Conda is available from this blog post.

There are also useful tips for running TensorFlow on Spark in this blog post.

Installation

First create a Conda env. This will take a few minutes.

./create-tf-conda-env.sh

Next, clone TensorFlowOnSpark and package its Python sources into tfspark.zip:

git clone https://github.com/yahoo/TensorFlowOnSpark
pushd TensorFlowOnSpark/src
zip -r ../tfspark.zip *
popd
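
The zip command runs from inside src. As a result, the tensorflowonspark package directory sits at the root of tfspark.zip rather than under a src/ prefix. That layout matters because Python can import packages directly from a zip file on sys.path (zipimport). This is presumably how train-mnist-dist.sh ships the library to the executors; the upstream YARN guide passes tfspark.zip to spark-submit with --py-files, but that is an assumption about the script here.

The following is a rough Python equivalent of the packaging step. The helper build_py_zip and the stand-in package demo_pkg are hypothetical names used only for illustration:

```python
import os
import sys
import tempfile
import zipfile


def build_py_zip(src_dir, zip_path):
    """Roughly what `cd src_dir && zip -r zip_path *` produces (hypothetical helper).

    Entries are stored relative to src_dir, so a package directory such as
    src_dir/tensorflowonspark/ ends up as tensorflowonspark/... at the zip root.
    """
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, dirs, files in os.walk(src_dir):
            dirs.sort()  # visit subdirectories in a stable order
            for name in sorted(files):
                full_path = os.path.join(root, name)
                zf.write(full_path, os.path.relpath(full_path, src_dir))
    return zip_path


if __name__ == "__main__":
    # Stand-in package; "demo_pkg" is a hypothetical name for illustration.
    work = tempfile.mkdtemp()
    pkg_dir = os.path.join(work, "src", "demo_pkg")
    os.makedirs(pkg_dir)
    with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
        f.write("VERSION = '0.1'\n")

    zip_path = build_py_zip(os.path.join(work, "src"), os.path.join(work, "demo.zip"))
    with zipfile.ZipFile(zip_path) as zf:
        print(zf.namelist())  # ['demo_pkg/__init__.py']

    # A zip file on sys.path is importable through Python's zipimport.
    sys.path.insert(0, zip_path)
    import demo_pkg
    print(demo_pkg.VERSION)  # 0.1
```

If the zip were built one directory higher, every entry would start with src/ and the import would fail. Running the command from inside src avoids that.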

Data preparation

Download the MNIST data, convert it into a format suitable for TensorFlow, and store it in HDFS:

./prepare-mnist-data.sh
./convert-mnist-data.sh
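
convert-mnist-data.sh determines the record layout, and its contents are not reproduced here. The following assumes it runs the upstream TensorFlowOnSpark mnist_data_setup.py with --format csv, as the upstream YARN guide does when writing to mnist/csv. Under that assumption:

  • Each line under mnist/csv/train/images is one 28x28 image flattened to 784 comma-separated pixel values.
  • The matching line under mnist/csv/train/labels is a one-hot vector of 10 values.

The sketch illustrates that encoding with hypothetical helper names. The number formatting in the real files may differ, for example 1 versus 1.0:

```python
def image_to_csv(pixels):
    """Flatten a 28x28 image (a list of 28 rows) into one CSV line of 784 values."""
    flat = [p for row in pixels for p in row]
    if len(flat) != 28 * 28:
        raise ValueError("expected 784 pixels, got %d" % len(flat))
    return ",".join(str(p) for p in flat)


def label_to_csv(digit, num_classes=10):
    """Encode a digit label as a one-hot CSV line, e.g. 3 -> "0,0,0,1,0,0,0,0,0,0"."""
    if not 0 <= digit < num_classes:
        raise ValueError("label out of range: %r" % digit)
    one_hot = [0] * num_classes
    one_hot[digit] = 1
    return ",".join(str(v) for v in one_hot)


def csv_to_floats(line):
    """Parse one CSV line back into floats, as a training-side reader might."""
    line = line.strip()
    return [float(x) for x in line.split(",")] if line else []


if __name__ == "__main__":
    image = [[0] * 28 for _ in range(28)]
    image[14][14] = 255  # a single bright pixel in the middle
    image_line = image_to_csv(image)
    print(len(image_line.split(",")))  # 784
    print(label_to_csv(3))  # 0,0,0,1,0,0,0,0,0,0
    print(csv_to_floats(label_to_csv(3))[3])  # 1.0
```

Images and labels live in separate directories. Under this assumption, they only line up if line N of one matches line N of the other.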

Check that the output files were created:

hadoop fs -ls -R mnist/csv

Training

Train the model by running:

./train-mnist-dist.sh

You can see the exact mnist_spark.py command that was run by looking in the YARN logs for the application; the log link is in the lower right of the screenshot:

[Screenshot: Running train-mnist-dist.sh]

The log records the parsed arguments:

args: Namespace(batch_size=100, cluster_size=4, epochs=1, format='csv', images='mnist/csv/train/images', labels='mnist/csv/train/labels', mode='train', model='mnist_model', output='predictions', rdma=False, readers=1, steps=1000, tensorboard=False)
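
cluster_size=4 matches the four nodes listed in the log below: one parameter server (job_name 'ps') and three workers. task_index is counted separately within each job.

The following minimal sketch shows that numbering under two assumptions:

  • There is a single parameter server (num_ps=1). This is inferred from the log, since num_ps does not appear in the arguments.
  • The lowest executor numbers (worker_num) go to the ps job. The log shows this, but it has not been checked against the TensorFlowOnSpark source.

assign_roles is a hypothetical helper:

```python
def assign_roles(cluster_size, num_ps=1):
    """Map executor number (worker_num) -> (job_name, task_index).

    Assumption: executors 0..num_ps-1 host parameter servers and the rest host
    workers, with task_index restarting at 0 for each job. This matches the
    worker_num/job_name/task_index values in the log, but is only a sketch.
    """
    if not 0 < num_ps < cluster_size:
        raise ValueError("need at least one ps and at least one worker")
    roles = {}
    for worker_num in range(cluster_size):
        if worker_num < num_ps:
            roles[worker_num] = ("ps", worker_num)
        else:
            roles[worker_num] = ("worker", worker_num - num_ps)
    return roles


if __name__ == "__main__":
    for num, (job, idx) in sorted(assign_roles(4).items()):
        print(num, job, idx)
    # 0 ps 0
    # 1 worker 0
    # 2 worker 1
    # 3 worker 2
```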

The log also lists every node that was started, including whether it is a worker or the parameter server (the job_name field):

2017-05-04 08:02:53,142 INFO (MainThread-60229) All TFSparkNodes started
2017-05-04 08:02:53,143 INFO (MainThread-60229) {'addr': '/tmp/pymp-_O_ESR/listener-8ojX3h', 'task_index': 0, 'port': 40479, 'authkey': 'A/\xb7\x12\xd81Nz\xae\xed\xb9\xc8\xcdk\xbe\xf0', 'worker_num': 1, 'host': 'bottou05.sjc.cloudera.com', 'ppid': 14391, 'job_name': 'worker', 'tb_pid': 0, 'tb_port': 0}
2017-05-04 08:02:53,143 INFO (MainThread-60229) {'addr': '/tmp/pymp-JHTvXb/listener-3bzG21', 'task_index': 2, 'port': 32911, 'authkey': '\xc3\xb2J\xf1^2E\xb8\x86\xf4\xfa:;\x05\xd8\xac', 'worker_num': 3, 'host': 'bottou05.sjc.cloudera.com', 'ppid': 14395, 'job_name': 'worker', 'tb_pid': 0, 'tb_port': 0}
2017-05-04 08:02:53,143 INFO (MainThread-60229) {'addr': '/tmp/pymp-mPraGu/listener-5goUGe', 'task_index': 1, 'port': 38436, 'authkey': '\x9fP\x1dQ\xe9\x8cH\xc4\xab4FMtm\xc1\xcf', 'worker_num': 2, 'host': 'bottou05.sjc.cloudera.com', 'ppid': 14393, 'job_name': 'worker', 'tb_pid': 0, 'tb_port': 0}
2017-05-04 08:02:53,143 INFO (MainThread-60229) {'addr': ('bottou05.sjc.cloudera.com', 39106), 'task_index': 0, 'port': 33978, 'authkey': 'r\t\xbf\x95\xbc\xbbGA\x99\xd2\xc2\xb1\x1b7\xcfs', 'worker_num': 0, 'host': 'bottou05.sjc.cloudera.com', 'ppid': 14392, 'job_name': 'ps', 'tb_pid': 0, 'tb_port': 0}
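
Each entry gives the host, port, job_name and task_index of one node. Those fields are enough to describe the cluster in the form TensorFlow's distributed runtime uses: a mapping from job name to an ordered list of host:port addresses. This is the kind of dict tf.train.ClusterSpec accepts.

The sketch parses log lines like the ones above and rebuilds that mapping. It is a reconstruction for reading the log, not the TensorFlowOnSpark code path, and the regular expression assumes exactly the log layout shown here:

```python
import ast
import re

# Matches node entries of the form "... INFO (MainThread-NNN) {...}" shown above.
NODE_LINE = re.compile(r"INFO \(MainThread-\d+\) (\{.*\})\s*$")


def parse_node(line):
    """Return the node dict logged on this line, or None for any other line."""
    match = NODE_LINE.search(line)
    return ast.literal_eval(match.group(1)) if match else None


def cluster_spec(nodes):
    """Group nodes into {job_name: ["host:port", ...]}, ordered by task_index.

    Assumes task indices within each job run 0..n-1 without gaps.
    """
    spec = {}
    for node in sorted(nodes, key=lambda n: (n["job_name"], n["task_index"])):
        address = "%s:%d" % (node["host"], node["port"])
        spec.setdefault(node["job_name"], []).append(address)
    return spec


if __name__ == "__main__":
    # Entries from the log above, trimmed to the fields this sketch uses.
    log = """\
2017-05-04 08:02:53,142 INFO (MainThread-60229) All TFSparkNodes started
2017-05-04 08:02:53,143 INFO (MainThread-60229) {'task_index': 0, 'port': 40479, 'host': 'bottou05.sjc.cloudera.com', 'job_name': 'worker'}
2017-05-04 08:02:53,143 INFO (MainThread-60229) {'task_index': 2, 'port': 32911, 'host': 'bottou05.sjc.cloudera.com', 'job_name': 'worker'}
2017-05-04 08:02:53,143 INFO (MainThread-60229) {'task_index': 1, 'port': 38436, 'host': 'bottou05.sjc.cloudera.com', 'job_name': 'worker'}
2017-05-04 08:02:53,143 INFO (MainThread-60229) {'task_index': 0, 'port': 33978, 'host': 'bottou05.sjc.cloudera.com', 'job_name': 'ps'}
"""
    nodes = [n for n in map(parse_node, log.splitlines()) if n is not None]
    for job, addresses in cluster_spec(nodes).items():
        print(job, addresses)
    # ps ['bottou05.sjc.cloudera.com:33978']
    # worker ['bottou05.sjc.cloudera.com:40479', 'bottou05.sjc.cloudera.com:38436', 'bottou05.sjc.cloudera.com:32911']
```

In this run all four nodes landed on the same host (bottou05), so they are distinguished only by port.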
