apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more

Home Page: https://mxnet.apache.org

Forward can't run in parallel on multiple GPUs when a custom operator uses NumPy

leoxiaobin opened this issue

Description

I found that after #6928, when a custom operator uses NumPy, the forward passes on multiple GPUs no longer run in parallel. Before #6928, they did run in parallel. This can be reproduced with MXNet's image-classification example by replacing the Softmax operator with a custom softmax.

Environment info

----------Python Info----------
('Version      :', '2.7.6')
('Compiler     :', 'GCC 4.8.4')
('Build        :', ('default', 'Oct 26 2016 20:30:19'))
('Arch         :', ('64bit', 'ELF'))
------------Pip Info-----------
('Version      :', '9.0.1')
('Directory    :', '/usr/local/lib/python2.7/dist-packages/pip-9.0.1-py2.7.egg/pip')
----------MXNet Info-----------
('Version      :', '0.10.0')
('Directory    :', '/data/home/xizhou/incubator-mxnet/python/mxnet')
Traceback (most recent call last):
  File "diagnose_new.py", line 171, in <module>
    check_mxnet()
  File "diagnose_new.py", line 113, in check_mxnet
    except FileNotFoundError:
NameError: global name 'FileNotFoundError' is not defined 
----------System Info----------
('Platform     :', 'Linux-3.13.0-132-generic-x86_64-with-Ubuntu-14.04-trusty')
('system       :', 'Linux')
('node         :', 'msravcg10')
('release      :', '3.13.0-132-generic')
('version      :', '#181-Ubuntu SMP Wed Sep 13 13:25:03 UTC 2017')
----------Hardware Info----------
('machine      :', 'x86_64')
('processor    :', 'x86_64')
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                32
On-line CPU(s) list:   0-31
Thread(s) per core:    2
Core(s) per socket:    8
Socket(s):             2
NUMA node(s):          2
Vendor ID:             GenuineIntel
CPU family:            6
Model:                 62
Stepping:              4
CPU MHz:               2600.000
BogoMIPS:              5188.77
Virtualization:        VT-x
L1d cache:             32K
L1i cache:             32K
L2 cache:              256K
L3 cache:              20480K
NUMA node0 CPU(s):     0-7,16-23
NUMA node1 CPU(s):     8-15,24-31
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0017 sec, LOAD: 1.1562 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0012 sec, LOAD: 0.4335 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.2028 sec, LOAD: 0.9514 sec.
Timing for Conda: https://repo.continuum.io/pkgs/free/, DNS: 0.0678 sec, LOAD: 0.4102 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.0684 sec, LOAD: 0.2063 sec.
Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [Errno 1] _ssl.c:510: error:14077410:SSL routines:SSL23_GET_SERVER_HELLO:sslv3 alert handshake failure>, DNS finished in 0.0689778327942 sec.

Package used (Python/R/Scala/Julia):
Python

Build info (Required if built from source)

Compiler (gcc/clang/mingw/visual studio):
gcc

MXNet commit hash:
ed19095

Minimum reproducible example

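This custom softmax operator exists only to reproduce the issue, so I did not implement the backward pass.

import mxnet as mx


class Softmax(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        # asnumpy() blocks until in_data[0] is ready and copies it to host memory;
        # the copy is moved back to the input's device before applying softmax.
        self.assign(out_data[0], req[0], mx.nd.softmax(mx.nd.array(in_data[0].asnumpy(), ctx=in_data[0].context)))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # Backward is intentionally a no-op; this op only reproduces the forward slowdown.
        self.assign(in_grad[0], req[0], 0)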

@mx.operator.register("softmax")
class SoftmaxProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(SoftmaxProp, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)
        output_shape = in_shape[0]
        return [data_shape, label_shape], [output_shape], []

    def infer_type(self, in_type):
        return in_type, [in_type[0]], []

    def create_operator(self, ctx, shapes, dtypes):
        return Softmax()
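To check the custom operator's output on the CPU, a minimal pure-Python reference for the row-wise softmax that `mx.nd.softmax` computes (over the last axis by default) might look like the sketch below. The helper name `softmax_rows` is hypothetical and not part of MXNet.

```python
import math
from typing import List


def softmax_rows(batch: List[List[float]]) -> List[List[float]]:
    """Row-wise softmax over the last axis (hypothetical reference helper)."""
    result = []
    for row in batch:
        # Subtract the row maximum for numerical stability; softmax is shift-invariant.
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        result.append([e / total for e in exps])
    return result


if __name__ == "__main__":
    probs = softmax_rows([[1.0, 2.0, 3.0], [1000.0, 1000.0, 1000.0]])
    # The second row has equal logits, so each probability is 1/3 (no overflow).
    print(probs[1])
```

Comparing this against `out_data[0].asnumpy()` for a small batch is one way to confirm the reproduction operator computes the same thing as the built-in softmax.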

Steps to reproduce

  1. Run the original train_imagenet.py as a baseline:
python train_imagenet.py --benchmark 1 --batch-size 128 --gpus 0,1,2,3

Training speed is:

INFO:root:Epoch[0] Batch [20]   Speed: 217.27 samples/sec       accuracy=0.113467
INFO:root:Epoch[0] Batch [40]   Speed: 217.81 samples/sec       accuracy=1.000000
  2. Run train_imagenet.py with the custom softmax, which calls asnumpy inside the custom operator:
python train_imagenet.py --benchmark 1 --batch-size 128 --gpus 0,1,2,3

Training speed is:

INFO:root:Epoch[0] Batch [20]   Speed: 114.91 samples/sec       accuracy=0.000000
INFO:root:Epoch[0] Batch [40]   Speed: 113.70 samples/sec       accuracy=0.000000

What have you tried to solve it?

I used MXNet's built-in profiler to get more detail on the execution time.
The original version:
[profiler screenshot: v0 12 0_softmax]
The version using the custom softmax:
[profiler screenshot: v0 12 0_asnumpy]
As can be seen, when the custom operator is used, the forward passes on the multiple GPUs run sequentially rather than in parallel.

I also tried an MXNet version from before #6928: with or without the custom softmax operator, the speed is almost the same.
Original training speed using MXNet before #6928:

INFO:root:Epoch[0] Batch [20]   Speed: 217.54 samples/sec       accuracy=0.232515
INFO:root:Epoch[0] Batch [40]   Speed: 214.66 samples/sec       accuracy=1.000000

Training speed with the custom softmax, using MXNet before #6928:

INFO:root:Epoch[0] Batch [20]   Speed: 217.28 samples/sec       accuracy=0.000000
INFO:root:Epoch[0] Batch [40]   Speed: 213.57 samples/sec       accuracy=0.000000
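To quantify the regression from the numbers above, a small script can average the logged `Speed:` values and express the custom-softmax throughput as a fraction of the built-in baseline for each MXNet version. The speed values are copied from this report; the regex assumes the `Speed: <float> samples/sec` format shown in the logs, and the helper names are hypothetical.

```python
import re
from statistics import mean
from typing import Dict, List

# "Speed:" values copied from the training logs in this report.
LOGS: Dict[str, List[str]] = {
    "built-in softmax, after #6928": ["Speed: 217.27 samples/sec", "Speed: 217.81 samples/sec"],
    "custom softmax, after #6928": ["Speed: 114.91 samples/sec", "Speed: 113.70 samples/sec"],
    "built-in softmax, before #6928": ["Speed: 217.54 samples/sec", "Speed: 214.66 samples/sec"],
    "custom softmax, before #6928": ["Speed: 217.28 samples/sec", "Speed: 213.57 samples/sec"],
}

SPEED_RE = re.compile(r"Speed: ([0-9.]+) samples/sec")


def mean_speed(lines: List[str]) -> float:
    """Average the samples/sec values found in a list of log lines."""
    return mean(float(SPEED_RE.search(line).group(1)) for line in lines)


def relative_speed(version: str) -> float:
    """Custom-softmax throughput as a fraction of the built-in baseline."""
    custom = mean_speed(LOGS[f"custom softmax, {version}"])
    builtin = mean_speed(LOGS[f"built-in softmax, {version}"])
    return custom / builtin


if __name__ == "__main__":
    for version in ("before #6928", "after #6928"):
        print(f"{version}: custom/built-in = {relative_speed(version):.2f}")
```

With these numbers the ratio is about 1.00 before #6928 and about 0.53 after it, i.e. the custom softmax roughly halves 4-GPU throughput only on the newer code.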

We at TuSimple have also observed this. It is the bottleneck for large-scale training of detection (DET) models.
We also found that changing the number of CPU workers does not alleviate it.
A viable workaround is to rewrite memory-oriented operators in pure C++, copying in_data from the GPU and copying out_data back to the GPU.

Currently the FExecType of CustomOp is kLocal (https://github.com/apache/incubator-mxnet/blob/master/src/operator/custom/custom.cc#L404), which runs on the scheduling thread without being pushed to the engine. Is this the reason custom ops cannot scale? Why was kLocal chosen for CustomOp in #6928? @piiswrong
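The question above is whether running an op inline on the scheduling thread serializes it across GPUs. A toy model can illustrate the mechanism; it is NOT MXNet's engine, and `ToyScheduler`, `exec_local`, and `make_forward` are hypothetical names. A blocking forward (standing in for one that waits on `asnumpy()`) executed inline holds up the next device's op, while pushing each op to a worker pool lets them overlap.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


class ToyScheduler:
    """Hypothetical model of a dependency scheduler (not MXNet's engine).

    exec_local=True runs each op inline on the calling (scheduling) thread,
    loosely mimicking kLocal-style execution; exec_local=False pushes each
    op to a worker pool, so ops for different devices can overlap.
    """

    def __init__(self, exec_local: bool, num_workers: int = 4):
        self.exec_local = exec_local
        self.pool = ThreadPoolExecutor(max_workers=num_workers)

    def run_all(self, ops: List[Callable[[], None]]) -> float:
        """Run every op and return the elapsed wall-clock time in seconds."""
        start = time.perf_counter()
        if self.exec_local:
            for op in ops:
                op()  # blocks the scheduling thread until this op finishes
        else:
            futures = [self.pool.submit(op) for op in ops]
            for f in futures:
                f.result()  # wait for all pushed ops to complete
        return time.perf_counter() - start

    def shutdown(self) -> None:
        self.pool.shutdown()


def make_forward(delay: float) -> Callable[[], None]:
    def forward() -> None:
        # Stand-in for a per-GPU custom-op forward that blocks for `delay` seconds.
        time.sleep(delay)
    return forward


if __name__ == "__main__":
    ops = [make_forward(0.1) for _ in range(2)]  # one op per "GPU"
    for local in (True, False):
        sched = ToyScheduler(exec_local=local)
        print(f"exec_local={local}: {sched.run_all(ops):.2f}s")
        sched.shutdown()
```

In this model the inline run takes roughly the sum of the delays and the pooled run roughly one delay, which matches the sequential pattern seen in the profiler; whether this is the actual cause in MXNet is exactly what the question asks.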

I have encountered the same problem when training object detection models. Can you fix this bug, or suggest an alternative solution? @piiswrong

I have encountered the same issue with a Seq2Seq model. Are there any solutions or workarounds?
@piiswrong

Hi @piiswrong @mli, it seems that many people have encountered the same issue. Is this a bug in MXNet?

I suspect this is an artifact of the GIL, since the custom op code is in Python. Design-wise, one way I see to circumvent the GIL is to parse the Python construct and pass it to the backend, which is not easy to do.

For now, if you care about performance, please write the operator in C++ instead.
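As a generic CPython illustration (not a measurement of MXNet), the GIL serializes threads that execute pure-Python bytecode, but threads blocked in calls that release the GIL, such as `time.sleep`, can overlap. The helper names below are hypothetical.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable


def blocking_work() -> int:
    # time.sleep releases the GIL while waiting.
    time.sleep(0.1)
    return 1


def python_cpu_work() -> int:
    # Pure-Python arithmetic holds the GIL while it runs (standard CPython builds).
    total = 0
    for i in range(2_000_000):
        total += i
    return total


def time_in_threads(fn: Callable[[], int], n: int = 2) -> float:
    """Run fn once on each of n threads and return the wall-clock time in seconds."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=n) as pool:
        list(pool.map(lambda _: fn(), range(n)))
    return time.perf_counter() - start


if __name__ == "__main__":
    print(f"2 x blocking_work on 2 threads:   {time_in_threads(blocking_work):.2f}s")
    print(f"1 x python_cpu_work:              {time_in_threads(python_cpu_work, n=1):.2f}s")
    print(f"2 x python_cpu_work on 2 threads: {time_in_threads(python_cpu_work):.2f}s")
```

On a standard (GIL-enabled) CPython build, the two sleeps typically finish in about 0.1 s, while the two-thread CPU-bound run typically takes about as long as two sequential runs. How much a custom op is limited by the GIL therefore depends on how much of its forward is Python-level computation versus calls that release the GIL.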

Thanks, @szha, but the speed is normal with the same code before PR #6928. The issue only appeared after that PR.

FYI: a simple workaround for a loss-type CustomOp is to comment out all calculations in forward and keep only the assign. This gives a fully parallel forward and an almost fully parallel backward, since the losses are then first computed during backward.
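A sketch of that workaround, assuming a softmax cross-entropy loss: forward only copies its input to the output (the `assign`), and backward computes the gradient directly as softmax(x) minus the one-hot label. It is written in plain Python so the math can be checked without a GPU; in a real CustomOp the rows would come from `in_data` and the result would be written to `in_grad` with `self.assign`. The function names are hypothetical, and no batch normalization of the gradient is applied here.

```python
import math
from typing import List


def softmax_row(row: List[float]) -> List[float]:
    m = max(row)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def forward_passthrough(logits: List[List[float]]) -> List[List[float]]:
    # Workaround forward: no loss computation, just copy the input to the output.
    return [list(row) for row in logits]


def backward_softmax_ce(logits: List[List[float]], labels: List[int]) -> List[List[float]]:
    # d CE(softmax(x), y) / dx = softmax(x) - onehot(y), computed per sample.
    grads = []
    for row, y in zip(logits, labels):
        p = softmax_row(row)
        p[y] -= 1.0
        grads.append(p)
    return grads


if __name__ == "__main__":
    print(backward_softmax_ce([[0.0, 0.0]], [1]))  # [[0.5, -0.5]]
```

Because the forward does no Python-side computation, the per-GPU forwards no longer wait on each other; the cost moves into backward, which is why that phase is only almost fully parallel.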

A recent PR makes the ops before and after a CustomOp run in parallel, but the CustomOp itself still runs sequentially. Does anyone have a clue why this happens?

entering 0: 1516522076.3805
exiting 0: 1516522076.6744
entering 1: 1516522076.6760
exiting 1: 1516522076.9477
entering 0: 1516522077.0904
exiting 0: 1516522077.3583
entering 1: 1516522077.3599
exiting 1: 1516522077.6237
entering 0: 1516522077.7664
exiting 0: 1516522078.0574
entering 1: 1516522078.0590
exiting 1: 1516522078.3297

The following MCVE runs in a 2-GPU setting:

import time
import mxnet as mx
import numpy as np


class DebugOperator(mx.operator.CustomOp):
    def __init__(self, **kwargs):
        super(DebugOperator, self).__init__()
        self.pos = kwargs.get("pos", None)

    def forward(self, is_train, req, in_data, out_data, aux):
        print("entering %d: %.4f" % (in_data[0][0].context.device_id, time.time()))
        time.sleep(0.1)
        self.assign(out_data[0], req[0], 0)
        print("exiting %d: %.4f" % (in_data[0][0].context.device_id, time.time()))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        self.assign(in_grad[0], req[0], 0)


@mx.operator.register("Debug")
class DebugProp(mx.operator.CustomOpProp):
    def __init__(self, **kwargs):
        super(DebugProp, self).__init__(need_top_grad=False)
        self._kwargs = kwargs

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        return in_shape, [(1, )]

    def create_operator(self, ctx, shapes, dtypes):
        return DebugOperator(**self._kwargs)


def get_symbol():
    data = mx.sym.var("data")
    label = mx.sym.var("softmax_label")
    proj = mx.sym.FullyConnected(data, num_hidden=1)
    debug = mx.sym.Custom(proj, op_type="Debug", name="debug")
    return mx.sym.Group([debug, label])


if __name__ == "__main__":
    gpus = [0, 1]
    sym = get_symbol()
    mod = mx.module.Module(sym, context=[mx.gpu(i) for i in gpus])
    mod.bind(data_shapes=[("data", (len(gpus), 1))], label_shapes=[("softmax_label", (len(gpus), 1))])
    data = mx.io.NDArrayIter(data=np.zeros((10000, 1)), label=np.zeros((10000, 1)), batch_size=len(gpus))
    mod.fit(data, num_epoch=1, eval_metric=mx.metric.Loss(output_names=["debug_output"]))

The output is:

entering 1: 1516523993.4081
exiting 1: 1516523993.5086
entering 0: 1516523993.5088
exiting 0: 1516523993.6092
entering 1: 1516523993.6362
exiting 1: 1516523993.7368
entering 0: 1516523993.7369
exiting 0: 1516523993.8373
entering 1: 1516523993.8394
exiting 1: 1516523993.9398
entering 0: 1516523993.9400
exiting 0: 1516523994.0404
entering 1: 1516523994.0634
exiting 1: 1516523994.1692
entering 0: 1516523994.1694
exiting 0: 1516523994.2698
entering 0: 1516523994.2750
exiting 0: 1516523994.3755
entering 1: 1516523994.3757
exiting 1: 1516523994.4761
entering 0: 1516523994.4873
exiting 0: 1516523994.5877
entering 1: 1516523994.5879
exiting 1: 1516523994.6883
entering 0: 1516523994.6943
exiting 0: 1516523994.7948
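To check a log like the one above for parallelism mechanically, one can pair each `entering`/`exiting` line per device into a time interval and count intervals on different devices that overlap; a count of zero means the per-GPU forwards never ran concurrently. The function names are hypothetical, and the parser assumes the `entering <device>: <timestamp>` format printed by the MCVE.

```python
import re
from typing import Dict, List, Tuple

LINE_RE = re.compile(r"(entering|exiting) (\d+): ([0-9.]+)")

# First eight lines of the MCVE output above.
SAMPLE_LOG = """\
entering 1: 1516523993.4081
exiting 1: 1516523993.5086
entering 0: 1516523993.5088
exiting 0: 1516523993.6092
entering 1: 1516523993.6362
exiting 1: 1516523993.7368
entering 0: 1516523993.7369
exiting 0: 1516523993.8373
"""


def parse_intervals(log: str) -> List[Tuple[int, float, float]]:
    """Pair each 'entering d' with the next 'exiting d' into (device, start, end)."""
    open_at: Dict[int, float] = {}
    intervals = []
    for match in LINE_RE.finditer(log):
        kind, dev, ts = match.group(1), int(match.group(2)), float(match.group(3))
        if kind == "entering":
            open_at[dev] = ts
        elif dev in open_at:
            intervals.append((dev, open_at.pop(dev), ts))
    return intervals


def count_cross_device_overlaps(intervals: List[Tuple[int, float, float]]) -> int:
    """Count pairs of intervals on different devices whose time ranges intersect."""
    count = 0
    for i, (dev_a, s_a, e_a) in enumerate(intervals):
        for dev_b, s_b, e_b in intervals[i + 1:]:
            if dev_a != dev_b and s_a < e_b and s_b < e_a:
                count += 1
    return count


if __name__ == "__main__":
    # 0: no two devices were inside forward at the same time.
    print(count_cross_device_overlaps(parse_intervals(SAMPLE_LOG)))
```

Running the same check on the full log gives a quick pass/fail signal when testing a fix, instead of reading timestamps by eye.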

With the latest code on master, this problem still exists.

Proposed labels: "Python", "Distributed", "Ubuntu"

Does this problem still exist? A recent PR addressing it, https://github.com/apache/incubator-mxnet/pull/9283/files, has been merged.

@leoxiaobin Does this issue still exist?

The issue still exists in the latest version, 1.3.1, and in 1.5.0 on master, @sxjscience @vandanavk. Is there any workaround other than writing C++ layers? @piiswrong