Scaling Distributed Machine Learning with System and Algorithm Co-design Mu Li Thesis Defense CSD, CMU Feb 2nd, 2017
Large-scale problems: $\min_w \sum_{i=1}^{n} f_i(w)$. Two ingredients: distributed systems and large scale optimization methods.
Large Scale Machine Learning. Machine learning learns from data. More data gives better accuracy and makes it possible to use more complex models. [Figure: accuracy versus data size, with more complex models reaching higher accuracy.]
Ads Click Prediction. Predict if an ad will be clicked. Each ad impression is an example. Model: logistic regression. A single machine processes 1 million examples per second.
Ads Click Prediction, industrial scale. A typical industrial size problem has 100 billion examples and 10 billion unique features. [Figure: training data size (TB), 0 to 700, by year, 2010 to 2014.]
Image Recognition. Recognize the object in an image. Model: convolutional neural network. A state-of-the-art network has hundreds of layers and needs billions of floating-point operations to process a single image.
Distributed Computing for Large Scale Problems. Distribute the workload among many machines. Widely available thanks to cloud providers (AWS, GCP, Azure). [Figure: machines connected through a hierarchy of switches.]
Distributed Computing for Large Scale Problems: challenges. Limited communication bandwidth (10x less than memory bandwidth). Large synchronization cost (1 ms latency). Job failures. [Figure: failure rate (%), 0 to 24, versus machine time (hours), 100 to 10000.]
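Distributed Optimization for Large Scale ML. Problem: $\min_w \sum_{i=1}^{n} f_i(w)$. Partition the examples across $m$ machines, $\bigcup_{i=1}^{m} I_i = \{1, 2, \ldots, n\}$. At iteration $t$, machine $k$ computes its partial gradient $\sum_{i \in I_k} \partial f_i(w_t)$ for $k = 1, \ldots, m$; the partial gradients are summed to update $w_t$ to $w_{t+1}$. Challenges: massive communication traffic; expensive global synchronization.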
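The partition-and-sum pattern on this slide can be checked numerically. The following is a minimal single-process sketch, assuming a logistic loss for $f_i$ and synthetic data; the names `logistic_grad`, `parts`, and the sizes are illustrative and not taken from the talk.

```python
import numpy as np

def logistic_grad(w, X, y):
    """Sum over rows of the gradient of log(1 + exp(-y_i <x_i, w>))."""
    margins = y * (X @ w)
    return -(X.T @ (y / (1.0 + np.exp(margins))))

rng = np.random.default_rng(0)
n, d, m = 1000, 20, 4                      # examples, features, machines (assumed sizes)
X = rng.normal(size=(n, d))
y = rng.choice([-1.0, 1.0], size=n)
w = rng.normal(size=d)

# Each "machine" k owns an index set I_k and computes a partial gradient.
parts = np.array_split(np.arange(n), m)
partial = [logistic_grad(w, X[idx], y[idx]) for idx in parts]

# Summing the partial gradients recovers the full gradient.
aggregated = np.sum(partial, axis=0)
full = logistic_grad(w, X, y)
```

In a real deployment each partial gradient would travel over the network, which is where the communication and synchronization costs listed on the slide come from.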
Scaling Distributed Machine Learning. Distributed systems + large scale optimization. Large data size, complex models; fault tolerant; communication efficient; convergence guarantee; easy to use.
Scaling Distributed Machine Learning: roadmap. Distributed systems: Parameter Server; MXNet for deep learning. Large scale optimization: DBPG for non-convex non-smooth $f_i$; EMSO for efficient minibatch SGD.
Existing Open Source Systems in 2012. MPI (message passing interface): hard to use for sparse problems; no fault tolerance. Key-value stores, e.g. Redis: expensive individual key-value pair communication; difficult to program on the server side. Hadoop/Spark: BSP data consistency makes efficient implementation challenging.
Parameter Server Architecture [Smola 10, Dean 12]. Server machines hold the model and apply updates. Worker machines hold the training data and communicate with the servers through push and pull.
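The push/pull interaction can be illustrated with a toy, in-process stand-in. This is only a sketch under stated assumptions: `ToyServer` and `worker_iteration` are hypothetical names, the update rule is plain gradient descent, and the loss is logistic; none of this is the interface of the actual system.

```python
import numpy as np

class ToyServer:
    """Holds the model and applies updates pushed by workers (hypothetical class)."""

    def __init__(self, dim, lr):
        self.w = np.zeros(dim)
        self.lr = lr

    def push(self, keys, grads):
        # Update only the pushed keys; ufunc.at also handles repeated keys.
        np.subtract.at(self.w, keys, self.lr * grads)

    def pull(self, keys):
        return self.w[keys].copy()

def worker_iteration(server, X, y):
    """One worker step: pull needed weights, compute a gradient, push it back."""
    keys = np.flatnonzero(np.abs(X).sum(axis=0))   # features present in this shard
    w = server.pull(keys)
    margins = y * (X[:, keys] @ w)
    grad = -(X[:, keys].T @ (y / (1.0 + np.exp(margins))))
    server.push(keys, grad)
    return keys
```

Because a worker only pulls and pushes the keys that occur in its shard, sparse data keeps the per-iteration traffic far below the full model size.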
Key Features of our Implementation [Li et al, OSDI 14]. Trade off data consistency for speed. Flexible consistency models. User-defined filters. Fault tolerance with chain replication.
Flexible Consistency Model. Sequential / BSP: the gradient push & pull of iteration 1 executes only after iteration 0 has finished, so tasks 1, 2, 3, 4 form a dependency chain. Eventual / total asynchronous [Smola 10]: no dependency among tasks 1, 2, 3, 4, so iterations 0 and 1 may overlap. Bounded delay / SSP [Langford 09, Cipar 13]: a task depends only on tasks that lie more than a fixed delay behind it. Flexible models are expressed via a task dependency graph.
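The three consistency models differ only in which earlier iterations a new iteration must wait for. A minimal sketch, assuming iterations are numbered from 0 and that a maximal delay of 0 means sequential and no limit means eventual; the function name `dependencies` is hypothetical.

```python
def dependencies(t, max_delay):
    """Iterations that must have finished before iteration t may start.

    max_delay = 0     -> sequential / BSP: wait for every earlier iteration
    max_delay = k > 0 -> bounded delay: wait for iterations 0 .. t - k - 1
    max_delay = None  -> eventual: no dependency at all
    """
    if max_delay is None:
        return set()
    return set(range(0, t - max_delay))

sequential = dependencies(3, 0)
bounded = dependencies(3, 1)
eventual = dependencies(3, None)
```

A scheduler built on this rule would start iteration `t` as soon as every iteration in `dependencies(t, max_delay)` has completed.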
User-defined Filters. User-defined encoder/decoder for efficient communication: data on the sending machine passes through an encoder, and a decoder restores it on the receiving machine. Lossless compression: general data compression (LZ, LZR, ...). Lossy compression: random skip; fixed-point encoding.
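An encoder/decoder pair of this kind can be sketched by chaining a lossy and a lossless stage. The example below is an assumption-laden illustration: it quantizes a one-dimensional float array to 8-bit fixed point and then applies `zlib` (a DEFLATE compressor from the LZ family) as the general-purpose stage; `encode` and `decode` are hypothetical names, not the system's filter interface.

```python
import zlib
import numpy as np

def encode(x):
    """Lossy 8-bit fixed-point encoding followed by lossless compression."""
    lo, hi = float(x.min()), float(x.max())
    scale = (hi - lo) / 255 or 1.0          # fall back to 1.0 for constant input
    q = np.round((x - lo) / scale).astype(np.uint8)
    return zlib.compress(q.tobytes()), lo, scale

def decode(payload, lo, scale):
    """Invert the lossless stage, then map the 8-bit codes back to floats."""
    q = np.frombuffer(zlib.decompress(payload), dtype=np.uint8)
    return q.astype(np.float64) * scale + lo

rng = np.random.default_rng(0)
x = rng.normal(size=1000)
payload, lo, scale = encode(x)
restored = decode(payload, lo, scale)
```

The reconstruction error of each entry is bounded by half a quantization step (`scale / 2`), while the payload is much smaller than the original 8 bytes per value.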
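Fault Tolerance with Chain Replication. The model is partitioned by consistent hashing. Chain replication: worker 0 pushes to server 0, server 0 pushes the update on to server 1, server 1 acks server 0, and server 0 then acks worker 0. Option: aggregation reduces backup traffic: server 0 first aggregates the pushes from worker 0 and worker 1, then sends a single push to server 1 before the acks flow back.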
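Partitioning by consistent hashing and picking a replication chain can be sketched as follows. This is an illustrative toy, assuming one ring position per server (no virtual nodes) and MD5 as the hash; `HashRing`, `ring_position`, and the server names are hypothetical.

```python
import bisect
import hashlib

def ring_position(name):
    """Map a string to a position on the hash ring."""
    return int(hashlib.md5(name.encode()).hexdigest(), 16)

class HashRing:
    def __init__(self, servers):
        self.points = sorted((ring_position(s), s) for s in servers)
        self.positions = [p for p, _ in self.points]

    def chain(self, key, replicas=2):
        """Owner of `key` followed by its successors on the ring (the replicas)."""
        start = bisect.bisect_right(self.positions, ring_position(str(key)))
        n = len(self.points)
        return [self.points[(start + i) % n][1] for i in range(min(replicas, n))]

ring = HashRing(["server0", "server1", "server2"])
owner, backup = ring.chain("feature:42")
```

With this layout, removing a server only remaps the keys that server owned; every other key keeps its owner, which limits the data movement after a failure.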
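Proximal Gradient Method [Combettes 09]. Problem: $\min_{w \in \Omega} \sum_{i=1}^{n} f_i(w) + h(w)$, where each $f_i$ is continuously differentiable but not necessarily convex, and $h$ is convex but possibly non-smooth. Iterative update: $w_{t+1} = \mathrm{Prox}_{\eta_t}\left[ w_t - \eta_t \sum_{i=1}^{n} \nabla f_i(w_t) \right]$, where $\mathrm{Prox}_{\eta}(x) := \operatorname{argmin}_{y \in \Omega} \; h(y) + \frac{1}{2\eta} \|x - y\|^2$.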
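For a concrete instance of the proximal operator, take $h(w) = \lambda \|w\|_1$ with no constraint set, an assumption chosen because it matches the sparse logistic regression used later in the talk. The operator then has the closed form known as soft-thresholding. The names `prox_l1`, `proximal_gradient_step`, `eta`, and `lam` are illustrative.

```python
import numpy as np

def prox_l1(x, eta, lam):
    """Prox of h(y) = lam * ||y||_1 with step eta: shrink each entry toward 0."""
    return np.sign(x) * np.maximum(np.abs(x) - eta * lam, 0.0)

def proximal_gradient_step(w, grad, eta, lam):
    """One update: gradient step on the smooth part, then the prox of h."""
    return prox_l1(w - eta * grad, eta, lam)

shrunk = prox_l1(np.array([3.0, -0.5, 1.0]), eta=1.0, lam=1.0)
stepped = proximal_gradient_step(np.array([1.0, 0.0]), np.array([-2.0, 0.5]),
                                 eta=0.5, lam=1.0)
```

Entries whose magnitude is at most `eta * lam` are set exactly to zero, which is how the method produces sparse models.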
Delayed Block Proximal Gradient [Li et al, NIPS 14]. Algorithm design tailored for the parameter server implementation. Update a block of coordinates each time. Allow delay among blocks. Use filters during communication. Only 300 lines of code.
Convergence Analysis. Assumptions: block Lipschitz continuity ($L_{\mathrm{var}}$ within a block, $L_{\mathrm{cor}}$ across blocks); delay is bounded by $\tau$; lossy compression such as the random skip filter and the significantly-modified filter. DBPG converges to a stationary point if the learning rate is chosen as $\eta_t < \frac{1}{L_{\mathrm{var}} + \tau L_{\mathrm{cor}}}$.
Experiments on Ads Click Prediction. Real dataset used in production: 170 billion examples, 65 billion unique features, 636 TB in total. 1000 machines. Sparse logistic regression: $\min_w \sum_{i=1}^{n} \log(1 + \exp(-y_i \langle x_i, w \rangle)) + \lambda \|w\|_1$. [Figure: time (hours), 0 to 2, to achieve the same objective value, split into computing and waiting, for sequential execution and for maximal delay τ = 0, 1, 2, 4, 8, 16; the best trade-off is 1.6x.]
Filters to Reduce Communication Traffic. Key caching: cache feature IDs on both sender and receiver. Data compression. KKT filter: shrink gradient to 0 based on the KKT condition. [Figure: traffic (%) for Baseline, Key Caching, Compressing, and KKT Filter. Server: reductions of 2x, 40x, 40x. Worker: reductions of 2x, 2.5x, 12x.]
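The idea behind the KKT filter can be illustrated for an $\ell_1$-regularized objective: a coordinate with $w_j = 0$ and $|g_j| \le \lambda$ stays at zero after a proximal step, so its gradient entry need not be sent. The sketch below is an assumption-based illustration, not the filter as implemented in the system; `kkt_filter`, `soft_threshold`, and the optional `margin` are hypothetical.

```python
import numpy as np

def kkt_filter(w, grad, lam, margin=0.0):
    """Zero out gradient entries that cannot move a zero weight away from zero."""
    skip = (w == 0) & (np.abs(grad) <= lam - margin)
    return np.where(skip, 0.0, grad), skip

def soft_threshold(x, thresh):
    return np.sign(x) * np.maximum(np.abs(x) - thresh, 0.0)

w = np.array([0.0, 0.0, 0.7, 0.0])
grad = np.array([0.3, -1.5, 0.2, -0.9])
lam, eta = 1.0, 0.1

filtered, skip = kkt_filter(w, grad, lam)

# The proximal step gives the same result with or without the filtered entries.
step_full = soft_threshold(w - eta * grad, eta * lam)
step_filtered = soft_threshold(w - eta * filtered, eta * lam)
```

Only the entries where `skip` is false would need to be communicated, which is where the traffic reduction comes from when most weights are zero.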
Deep Learning is Unique. Complex workloads. Heterogeneous computing. Easy to use programming interface. [Figure: deep learning trend in the past 5 years.]
Key Features of MXNet [Chen et al, NIPS 15 workshop] (corresponding author). Easy-to-use front-end: mixed programming. Scalable and efficient back-end: computation and memory optimization; auto-parallelization; scaling to multiple GPUs/machines.
Mixed Programming.
Imperative programming is flexible (e.g. Numpy, Matlab, Torch). Good for updating and interacting with the neural network:
    import mxnet as mx
    a = mx.nd.zeros((100, 50))
    b = mx.nd.ones((100, 50))
    c = a * b
    print(c)
    c += 1
Declarative programs are easy to optimize (e.g. TensorFlow, Theano, Caffe). Good for defining the neural network:
    import mxnet as mx
    net = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(data=net, num_hidden=128)
    net = mx.symbol.SoftmaxOutput(data=net)
    model = mx.module.Module(net)
    # Slide shorthand: the real Module API also needs bind() and
    # init_params(), and forward() takes a DataBatch.
    model.forward(data=c)
    model.backward()
Back-end System.
Front-end, imperative part:
    import mxnet as mx
    a = mx.nd.zeros((100, 50))
    b = mx.nd.ones((100, 50))
    c = a * b
    c += 1
Front-end, declarative part (slide shorthand, as on the previous slide):
    net = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(data=net, num_hidden=128)
    net = mx.symbol.SoftmaxOutput(data=net)
    texec = mx.module.Module(net)
    texec.forward(data=c)
    texec.backward()
Back-end. [Figure: computation graph with nodes a, b, c, + 1, weight, bias, fullc, softmax.] Optimization: memory optimization; operator fusion and runtime compilation. Scheduling: auto-parallelization.
Scale to Multiple GPU Machines. Hierarchical parameter server. From top to bottom: network switch, 10 Gbit Ethernet (1.25 GB/s); CPUs as level-2 servers; PCIe 3.0 16x (15.75 GB/s); PCIe switch as level-1 servers; 4 PCIe 3.0 16x (63 GB/s); GPUs as workers. 1000 lines of code.
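The bandwidth figures on this slide can be re-derived from the link specifications. The sketch below assumes PCIe 3.0 runs at 8 GT/s per lane with 128b/130b encoding, which are standard values not stated on the slide; the variable names are illustrative.

```python
# 10 Gbit Ethernet: divide by 8 bits per byte.
ethernet_gb_per_s = 10 / 8

# PCIe 3.0: 8 GT/s per lane, 128 payload bits per 130 transferred bits.
pcie3_lane_gb_per_s = 8 * (128 / 130) / 8
pcie3_x16_gb_per_s = 16 * pcie3_lane_gb_per_s

# Four 16x links between the GPUs and the PCIe switches.
four_x16_gb_per_s = 4 * pcie3_x16_gb_per_s

# Ratio between the fastest and slowest tier of the hierarchy.
gap = four_x16_gb_per_s / ethernet_gb_per_s
```

The roughly 50x gap between the GPU-side PCIe bandwidth and the Ethernet link is a reason to aggregate inside each machine (level 1) before sending anything across machines (level 2).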
Experiment Setup. Data: 1.2 million images with 1000 classes. Model: ResNet 152-layer. Hardware: EC2 p2.8xlarge, 8 K80 GPUs per machine (GPU 0-7 attached to the CPU through PCIe switches). Minibatch SGD: draw a random set of examples $I_t$ at iteration $t$, then update $w_{t+1} = w_t - \frac{\eta_t}{|I_t|} \sum_{i \in I_t} \partial f_i(w_t)$. Synchronized updating.
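The minibatch SGD update can be written in a few lines. The sketch below swaps the ResNet from the experiment for a noiseless least-squares problem so that it runs anywhere; the loss, the sizes, and the name `minibatch_sgd_step` are assumptions made for illustration.

```python
import numpy as np

def minibatch_sgd_step(w, X, y, eta, batch_size, rng):
    """w_{t+1} = w_t - eta / |I_t| * sum_{i in I_t} grad f_i(w_t)."""
    idx = rng.choice(len(y), size=batch_size, replace=False)   # random set I_t
    residual = X[idx] @ w - y[idx]
    grad_sum = X[idx].T @ residual     # sum of gradients of 0.5 * (<x_i, w> - y_i)^2
    return w - eta / batch_size * grad_sum

rng = np.random.default_rng(0)
X = rng.normal(size=(512, 5))
w_true = rng.normal(size=5)
y = X @ w_true                         # noiseless targets (assumption)

w = np.zeros(5)
for _ in range(500):
    w = minibatch_sgd_step(w, X, y, eta=0.1, batch_size=32, rng=rng)
```

With synchronized updating across machines, the sum inside the step is exactly what the workers compute in parts and the servers aggregate.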
Communication Cost. Fix the number of GPUs per machine. [Figure: time (sec), 0.2 to 0.7, versus number of GPUs, 0 to 128, for 1, 2, 4, and 8 GPUs per machine.]
Scalability. 115x speedup. [Figure: time (sec), 0 to 1, versus number of GPUs, 0 to 128; curves for communication cost and for batch size per GPU = 4, 8, 16.]
Convergence. [Figure: top-1 validation accuracy, 0.1 to 0.8, versus number of epochs, 0 to 120, for batch size 256, 2,560, and 5,120. Annotations: increase learning rate by 5x; increase learning rate by 10x, decrease it at epochs 50 and 80.]
Minibatch SGD. Large batch size $b$: better parallelization within a batch; less switching/communication cost. Small batch size $b$: faster convergence. Convergence rate: $O(1/\sqrt{N} + b/N)$, where $N$ is the number of examples processed. [Figure: as batch size $b$ grows, system performance gets better and the convergence rate gets worse.]
Motivation [Li et al, KDD 14]. Improve the convergence rate for large batch sizes. Example variance decreases with batch size. Solve a more accurate optimization subproblem over each batch. [Figure: as batch size $b$ grows, system performance gets better and the convergence rate gets worse.]
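Efficient Minibatch SGD. Define $f_{I_t}(w) := \sum_{i \in I_t} f_i(w)$. Minibatch SGD solves $w_t = \operatorname{argmin}_{w \in \Omega} \left[ f_{I_t}(w_{t-1}) + \langle \partial f_{I_t}(w_{t-1}), w - w_{t-1} \rangle + \frac{1}{2\eta_t} \|w - w_{t-1}\|_2^2 \right]$, a first-order approximation plus a conservative penalty. EMSO solves the subproblem $w_t = \operatorname{argmin}_{w \in \Omega} \left[ f_{I_t}(w) + \frac{1}{2\eta_t} \|w - w_{t-1}\|_2^2 \right]$ at each iteration, the exact objective plus the conservative penalty. For convex $f_i$, choose $\eta_t = O(b/\sqrt{N})$. EMSO has convergence rate $O(1/\sqrt{N})$ (compared to $O(1/\sqrt{N} + b/N)$).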
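The difference between the two subproblems can be made concrete on a case where the EMSO subproblem has a closed form. The sketch below assumes least-squares losses $f_i(w) = \frac{1}{2}(\langle x_i, w \rangle - y_i)^2$ and no constraint set, neither of which comes from the talk; `sgd_step`, `emso_step`, and `subproblem` are hypothetical names.

```python
import numpy as np

def sgd_step(w_prev, X, y, eta):
    """Minimizer of the first-order model plus (1 / (2 eta)) ||w - w_prev||^2."""
    return w_prev - eta * (X.T @ (X @ w_prev - y))

def emso_step(w_prev, X, y, eta):
    """Minimizer of the exact batch objective plus the same penalty.

    Setting the gradient to zero gives (X^T X + I / eta) w = X^T y + w_prev / eta.
    """
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + np.eye(d) / eta, X.T @ y + w_prev / eta)

def subproblem(w, w_prev, X, y, eta):
    """EMSO objective: f_It(w) + (1 / (2 eta)) ||w - w_prev||^2."""
    return 0.5 * np.sum((X @ w - y) ** 2) + np.sum((w - w_prev) ** 2) / (2 * eta)

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 5))           # one minibatch I_t (assumed sizes)
y = rng.normal(size=64)
w_prev = np.zeros(5)
eta = 0.05

w_sgd = sgd_step(w_prev, X, y, eta)
w_emso = emso_step(w_prev, X, y, eta)
```

For general convex losses the subproblem has no closed form and would be solved approximately by an inner solver, which is the extra per-batch work EMSO trades for a better convergence rate at large batch sizes.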
Experiment. Ads click prediction with fixed run time. [Figure: objective versus batch size, 1e3 to 1e5, for SGD and EMSO. Single machine panel: objective 0.01 to 0.035. 10 machines panel: objective 0.195 to 0.22.] Extended to deep learning in [Keskar et al, arXiv 16].
Large-scale problems: $\min_w \sum_{i=1}^{n} f_i(w)$. Co-design of distributed systems and large scale optimization. Reduce communication cost: communicate less; message compression; relaxed data consistency. With appropriate computational frameworks and algorithm design, distributed machine learning can be made simple, fast, and scalable, both in theory and in practice.
Acknowledgement. Committee members; QQ & Alex; advisors; 13 other collaborators.
Backup slides
Scaling to 16 GPUs in a Single Machine. [Figure: time (sec), 0 to 1, versus number of GPUs, 0 to 16; curves for communication cost and for batch size per GPU = 2, 4, 8, 16. Annotations: 15x; communication dominates.]
Compare to an L-BFGS Based System. [Figure: objective versus time (hours), 0.1 to 10, for System A and Parameter Server; time (hours), 0 to 5, split into computing and waiting, for System A and Parameter Server.]
Sections not Covered. AdaDelay: model the actual delay for asynchronous SGD. Parsa: data placement to reduce communication cost. DiFacto: large scale factorization machines.