Auto-Encoding Variational Bayes
Diederik P. (Durk) Kingma and Max Welling, University of Amsterdam
Kingma: Ph.D. candidate, advised by Max Welling
Problem class
Directed graphical model:
x: observed variables
z: latent variables (continuous)
θ: model parameters
$p_\theta(x, z)$: joint PDF, factorized and differentiable
Hard case: the posterior distribution $p_\theta(z \mid x)$ is intractable, e.g. when neural nets are used as components.
We want fast approximate posterior inference per datapoint. After inference, learning the parameters is easy.
Latent variable generative model
A latent variable model learns a mapping from some latent variable z to a complicated distribution on x:
$p(x) = \int p(x, z)\,dz$, where $p(x, z) = p(x \mid z)\,p(z)$, $p(z)$ = something simple, $p(x \mid z) = f(z)$.
Can we learn to decouple the true explanatory factors underlying the data distribution? E.g., separate identity and expression in face images.
[Figure: a mapping f from a 2-D latent space (z1, z2) to a 3-D data space (x1, x2, x3). Image from: Ward, A. D., Hamarneh, G.: 3D Surface Parameterization Using Manifold Learning for Medial Shape Representation, Conference on Image Processing, Proc. of SPIE Medical Imaging, 2007.]
IFT6266: Representation (Deep) Learning, Aaron Courville
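As a toy illustration (our own assumption, not from the slides), take a one-dimensional model with p(z) = N(0, 1) and p(x|z) = N(x; z, 1), i.e. f(z) = z, whose marginal is known in closed form: p(x) = N(x; 0, 2). Since p(x) = ∫ p(x|z) p(z) dz is an expectation under the prior, it can be estimated by averaging p(x|z) over prior samples. This naive estimator is fine in one dimension but becomes extremely noisy when z is high-dimensional and p(x|z) is sharply peaked.

```python
import math
import random

def normal_pdf(x, mean, var):
    """Density of N(mean, var) at x."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def naive_marginal(x, n_samples, rng):
    """Estimate p(x) = E_{z ~ p(z)}[p(x | z)] by sampling z from the prior."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(0.0, 1.0)          # z ~ p(z) = N(0, 1)
        total += normal_pdf(x, z, 1.0)   # p(x | z) = N(x; z, 1), i.e. f(z) = z
    return total / n_samples

rng = random.Random(0)
x = 1.3
estimate = naive_marginal(x, 50_000, rng)
exact = normal_pdf(x, 0.0, 2.0)  # closed-form marginal of this toy model
print(f"estimate={estimate:.4f} exact={exact:.4f}")
```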
Variational autoencoder (VAE) approach
Leverage neural networks to learn a latent variable model:
$p(x) = \int p(x, z)\,dz$, where $p(x, z) = p(x \mid z)\,p(z)$, $p(z)$ = something simple, $p(x \mid z) = f(z)$, with f a neural network.
[Figure: network f(z) mapping latent z = (z1, z2) to data x = (x1, x2, x3).]
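A minimal sketch of f(z) as a neural network, assuming a one-hidden-layer MLP with tanh hidden units and sigmoid outputs, so that f(z) gives the means of independent Bernoulli pixels (one common choice for binary data). The layer sizes and function names are hypothetical, and the weights are random placeholders standing in for the parameters θ that training would learn.

```python
import math
import random

def make_decoder(latent_dim, hidden_dim, data_dim, rng):
    """Random weights for a hypothetical one-hidden-layer decoder."""
    w1 = [[rng.gauss(0, 0.5) for _ in range(latent_dim)] for _ in range(hidden_dim)]
    b1 = [0.0] * hidden_dim
    w2 = [[rng.gauss(0, 0.5) for _ in range(hidden_dim)] for _ in range(data_dim)]
    b2 = [0.0] * data_dim
    return w1, b1, w2, b2

def decode(params, z):
    """f(z): map a latent vector z to Bernoulli means in (0, 1)."""
    w1, b1, w2, b2 = params
    h = [math.tanh(sum(w * zi for w, zi in zip(row, z)) + b) for row, b in zip(w1, b1)]
    logits = [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(w2, b2)]
    return [1.0 / (1.0 + math.exp(-a)) for a in logits]

rng = random.Random(0)
params = make_decoder(latent_dim=2, hidden_dim=8, data_dim=5, rng=rng)
z = [rng.gauss(0, 1), rng.gauss(0, 1)]  # z ~ p(z) = N(0, I)
print(decode(params, z))
```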
What can a VAE do?
[Figure 4: Visualisations of learned data manifolds for generative models with two-dimensional latent space, learned with AEVB. (a) Learned Frey Face manifold (z1: pose, z2: expression). (b) Learned MNIST manifold. Since the prior of the latent space is Gaussian, linearly spaced coordinates on the unit square were transformed through the inverse CDF of the Gaussian to produce values of the latent variables z. For each of these values z, we plotted the corresponding generative $p_\theta(x \mid z)$.]
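The caption describes how the manifold plots are built: linearly spaced points on the unit square are pushed through the inverse CDF of the standard Gaussian, so the grid covers the prior evenly in probability mass. A sketch using Python's `statistics.NormalDist`; the grid size and the margin that keeps the endpoints away from 0 and 1 (where the inverse CDF is infinite) are our assumptions. Each resulting z would then be decoded with p_θ(x|z) to draw one tile of the figure.

```python
from statistics import NormalDist

def latent_grid(n, margin=0.05):
    """n x n grid of 2-D latent points whose coordinates are Gaussian quantiles."""
    std_normal = NormalDist(mu=0.0, sigma=1.0)
    # Linearly spaced probabilities in (0, 1), avoiding exactly 0 and 1.
    probs = [margin + (1 - 2 * margin) * i / (n - 1) for i in range(n)]
    coords = [std_normal.inv_cdf(p) for p in probs]
    return [[(z1, z2) for z1 in coords] for z2 in coords]

grid = latent_grid(5)
print([round(z1, 3) for z1, _ in grid[0]])
```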
The inference / learning challenge
Where does z come from? The classic directed model dilemma: computing the posterior $p(z \mid x)$ is intractable, but we need it to train the directed model.
[Figure: unknown latent z = (z1, z2) mapped by f(z) to data x = (x1, x2, x3).]
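To see where the difficulty lies: p(z|x) = p(x|z) p(z) / p(x), and the normaliser p(x) is the same integral over z. In one dimension it can be brute-forced on a grid, as sketched below for an assumed toy model with p(z) = N(0, 1) and p(x|z) = N(x; z, 1), whose exact posterior is N(x/2, 1/2). With K grid points per dimension the cost grows as K^d, which is hopeless for realistic latent dimensionalities.

```python
import math

def normal_pdf(x, mean, var):
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def grid_posterior_mean(x, lo=-6.0, hi=6.0, k=2001):
    """Posterior mean of z given x, normalising p(x|z) p(z) with a Riemann sum on a 1-D grid."""
    step = (hi - lo) / (k - 1)
    zs = [lo + i * step for i in range(k)]
    # Unnormalised posterior p(x|z) p(z) at each grid point.
    weights = [normal_pdf(x, z, 1.0) * normal_pdf(z, 0.0, 1.0) for z in zs]
    evidence = sum(weights) * step  # approximates p(x)
    mean = sum(z * w for z, w in zip(zs, weights)) * step / evidence
    return mean, evidence

x = 1.3
mean, evidence = grid_posterior_mean(x)
print(mean, x / 2)   # grid estimate vs exact posterior mean
print(2001 ** 10)    # grid points the same resolution would need for a 10-D latent space
```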
Auto-Encoding Variational Bayes
Idea: learn a neural net to approximate the posterior, $q_\phi(z \mid x)$, with 'variational parameters' φ:
one-shot approximate inference, akin to the recognition model in Wake-Sleep.
Construct an estimator of the variational lower bound that we can optimize w.r.t. φ jointly with θ, using stochastic gradient ascent.
Variational Lower Bound of the marg. lik.
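The bound satisfies log p(x) = L(θ, φ, x) + D_KL(q(z|x) || p(z|x)), so it never exceeds log p(x) and is tight exactly when q matches the true posterior. The sketch checks this numerically for an assumed toy model where everything is closed form: p(z) = N(0, 1), p(x|z) = N(x; z, 1), exact posterior N(x/2, 1/2), and q(z|x) = N(m, s²) with hand-picked m and s.

```python
import math

def log_normal(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def elbo(x, m, s):
    """Closed-form bound for p(z)=N(0,1), p(x|z)=N(x; z, 1), q(z|x)=N(m, s^2)."""
    exp_loglik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s ** 2)
    exp_logprior = -0.5 * math.log(2 * math.pi) - 0.5 * (m ** 2 + s ** 2)
    entropy = 0.5 * math.log(2 * math.pi * math.e * s ** 2)
    return exp_loglik + exp_logprior + entropy

def kl_normal(m1, v1, m2, v2):
    """D_KL(N(m1, v1) || N(m2, v2)) for scalar Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

x = 1.3
log_px = log_normal(x, 0.0, 2.0)  # exact log marginal likelihood
results = []
for m, s in [(0.0, 1.0), (0.5, 0.3), (x / 2, math.sqrt(0.5))]:
    gap = log_px - elbo(x, m, s)
    kl = kl_normal(m, s ** 2, x / 2, 0.5)  # KL to the exact posterior
    results.append((gap, kl))
    print(f"m={m:.2f} s={s:.2f} gap={gap:.6f} KL={kl:.6f}")
```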
Monte Carlo estimator of the variational bound
Can we differentiate through the sampling process w.r.t. φ?
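A plain Monte Carlo estimate of the bound averages log p(x, z) − log q(z|x) over samples z ∼ q(z|x). The sketch assumes a toy model with p(z) = N(0, 1), p(x|z) = N(x; z, 1) and q(z|x) = N(m, s²), and compares the estimate with the closed form. Note that it treats the sampler (`rng.gauss(m, s)`) as a black box, which is exactly why the question of differentiating w.r.t. φ = (m, s) arises.

```python
import math
import random

def log_normal(x, mean, var):
    """log N(x; mean, var)."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def mc_bound(x, m, s, n_samples, rng):
    """Average of log p(x, z) - log q(z|x) over z ~ q(z|x) = N(m, s^2)."""
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(m, s)  # black-box sample from q
        log_joint = log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0)
        total += log_joint - log_normal(z, m, s * s)
    return total / n_samples

def exact_bound(x, m, s):
    """Closed form of the same bound, for comparison."""
    return (-math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s * s)
            - 0.5 * (m * m + s * s) + 0.5 * math.log(2 * math.pi * math.e * s * s))

x, m, s = 1.3, 0.5, 0.8
estimate = mc_bound(x, m, s, 20_000, random.Random(0))
print(estimate, exact_bound(x, m, s))
```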
Variational Autoencoder (VAE)
Where does z come from? The classic DAG problem.
The VAE approach: introduce an inference machine $q_\phi(z \mid x)$ that learns to approximate the posterior $p_\theta(z \mid x)$.
- Define a variational lower bound on the data likelihood, $\log p_\theta(x) \ge \mathcal{L}(\theta, \phi, x)$:
$\mathcal{L}(\theta, \phi, x) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x, z) - \log q_\phi(z \mid x)]$
$= \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z) + \log p_\theta(z) - \log q_\phi(z \mid x)]$
$= \underbrace{-D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z))}_{\text{regularization term}} + \underbrace{\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]}_{\text{reconstruction term}}$
What is $q_\phi(z \mid x)$?
VAE inference model
The VAE approach: introduce an inference model $q_\phi(z \mid x)$ that learns to approximate the intractable posterior $p_\theta(z \mid x)$ by optimizing the variational lower bound:
$\mathcal{L}(\theta, \phi, x) = -D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z)) + \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$
We parameterize $q_\phi(z \mid x)$ with another neural network, $q_\phi(z \mid x) = q(z; g(x, \phi))$, alongside the decoder $p_\theta(x \mid z) = p(x; f(z, \theta))$.
[Figure: encoder g(x) mapping x to z; decoder f(z) mapping z back to x.]
Key reparameterization trick
Construct samples $z \sim q_\phi(z \mid x)$ in two steps:
1. $\epsilon \sim p(\epsilon)$ (random seed independent of φ)
2. $z = g(\phi, \epsilon, x)$ (differentiable perturbation)
such that $z \sim q_\phi(z \mid x)$ (the correct distribution).
Examples:
- if $q(z \mid x) = \mathcal{N}(\mu(x), \sigma(x)^2)$: $\epsilon \sim \mathcal{N}(0, I)$, $z = \mu(x) + \sigma(x) \odot \epsilon$
- (approximate) inverse CDF
- many more possibilities (see paper)
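A sketch of the two examples, with hypothetical parameter values: a Gaussian sampled as z = μ + σ·ε with ε ∼ N(0, 1), and an inverse-CDF construction, here for an exponential distribution with rate λ (our choice of example), where z = −ln(1 − u)/λ with u ∼ Uniform[0, 1). In both cases the randomness lives in ε or u, which does not depend on the parameters, and z is a differentiable function of the parameters.

```python
import math
import random
import statistics

def sample_gaussian(mu, sigma, rng):
    eps = rng.gauss(0.0, 1.0)   # parameter-free noise eps ~ N(0, 1)
    return mu + sigma * eps     # differentiable in mu and sigma

def sample_exponential(rate, rng):
    u = rng.random()                    # parameter-free noise u ~ Uniform[0, 1)
    return -math.log(1.0 - u) / rate    # inverse CDF of Exponential(rate)

rng = random.Random(0)
gauss_draws = [sample_gaussian(2.0, 0.5, rng) for _ in range(50_000)]
exp_draws = [sample_exponential(4.0, rng) for _ in range(50_000)]
print(statistics.fmean(gauss_draws), statistics.stdev(gauss_draws))  # should be close to 2.0 and 0.5
print(statistics.fmean(exp_draws))                                   # should be close to 1/4
```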
Reparametrization trick
Adding a few details + one really important trick.
Let's consider z to be real and $q_\phi(z \mid x) = \mathcal{N}(z; \mu_z(x), \sigma_z(x))$.
Parametrize z as $z = \mu_z(x) + \sigma_z(x)\,\epsilon_z$, where $\epsilon_z \sim \mathcal{N}(0, 1)$.
(Optional) Parametrize x as $x = \mu_x(z) + \sigma_x(z)\,\epsilon_x$, where $\epsilon_x \sim \mathcal{N}(0, 1)$.
[Figure: encoder g(x) outputs μ_z(x) and σ_z(x); decoder f(z) outputs μ_x(z) and σ_x(z).]
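Because z = μ + σ·ε is an ordinary differentiable function, the gradient of E_q[h(z)] with respect to μ and σ can be estimated by averaging h'(z)·∂z/∂μ = h'(z) and h'(z)·∂z/∂σ = h'(z)·ε over samples of ε. The sketch uses h(z) = z², an arbitrary test function we picked because the exact answers are known: E[z²] = μ² + σ², so the gradients are 2μ and 2σ.

```python
import random
import statistics

def pathwise_grads(mu, sigma, n_samples, rng):
    """Reparameterised gradient estimates of E_{z ~ N(mu, sigma^2)}[z^2]."""
    grads_mu, grads_sigma = [], []
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps
        dh_dz = 2.0 * z                  # derivative of h(z) = z^2
        grads_mu.append(dh_dz * 1.0)     # dz/dmu = 1
        grads_sigma.append(dh_dz * eps)  # dz/dsigma = eps
    return statistics.fmean(grads_mu), statistics.fmean(grads_sigma)

rng = random.Random(0)
g_mu, g_sigma = pathwise_grads(mu=1.5, sigma=0.7, n_samples=50_000, rng=rng)
print(g_mu, 2 * 1.5)     # estimate vs exact d/dmu
print(g_sigma, 2 * 0.7)  # estimate vs exact d/dsigma
```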
SGVB estimator
Really simple and appropriate for differentiation w.r.t. φ and θ!
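When the KL term is available in closed form, one form of the SGVB estimator is L̃(θ, φ; x) = −D_KL(q_φ(z|x) || p_θ(z)) + (1/L) Σ_l log p_θ(x|z^(l)), with z^(l) = μ + σ·ε^(l) and ε^(l) ∼ N(0, 1). The sketch assumes a 1-D toy model, p(z) = N(0, 1), p(x|z) = N(x; z, 1), q(z|x) = N(μ, σ²), and checks that the estimator averages to the closed-form bound; L = 1 is the cheap single-sample setting.

```python
import math
import random

def sgvb_estimate(x, mu, sigma, n_samples, rng):
    """SGVB estimate of the bound for p(z)=N(0,1), p(x|z)=N(x; z, 1), q(z|x)=N(mu, sigma^2)."""
    # Analytic KL(q(z|x) || p(z)) for a 1-D Gaussian against N(0, 1).
    kl = 0.5 * (mu ** 2 + sigma ** 2 - math.log(sigma ** 2) - 1.0)
    recon = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps                                        # z = g_phi(eps, x)
        recon += -0.5 * math.log(2 * math.pi) - 0.5 * (x - z) ** 2  # log p(x|z)
    return -kl + recon / n_samples

rng = random.Random(0)
x, mu, sigma = 1.3, 0.5, 0.8
single = sgvb_estimate(x, mu, sigma, n_samples=1, rng=rng)  # L = 1
averaged = sgvb_estimate(x, mu, sigma, n_samples=20_000, rng=rng)
exact = (-0.5 * math.log(2 * math.pi) - 0.5 * ((x - mu) ** 2 + sigma ** 2)
         - 0.5 * (mu ** 2 + sigma ** 2 - math.log(sigma ** 2) - 1.0))
print(single, averaged, exact)
```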
Variational auto-encoder
[Figure: x → encoder q(z|x) = N(μ, σ) → z (with injected noise ε) → decoder p(x|z), prior p(z) → x.]
Why reparametrization helps
September 19, 2016
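One way to see why the trick helps, under assumptions of our own (test function h(z) = z² and q = N(μ, σ²)): the alternative score-function (likelihood-ratio, REINFORCE-style) estimator h(z)·∂/∂μ log q(z) is also unbiased for ∂/∂μ E[h(z)] = 2μ, but typically has much higher variance than the pathwise estimator h'(z)·∂z/∂μ obtained by reparameterizing. For μ = 2, σ = 1 the exact per-sample variances are 87 and 4 respectively.

```python
import random
import statistics

def compare_estimators(mu, sigma, n_samples, rng):
    """Per-sample estimates of d/dmu E_{z ~ N(mu, sigma^2)}[z^2] (exact value 2*mu)."""
    score, pathwise = [], []
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps
        score.append(z ** 2 * (z - mu) / sigma ** 2)  # h(z) * d/dmu log q(z)
        pathwise.append(2.0 * z)                      # h'(z) * dz/dmu
    return score, pathwise

rng = random.Random(0)
score, pathwise = compare_estimators(mu=2.0, sigma=1.0, n_samples=50_000, rng=rng)
for name, g in [("score function", score), ("reparameterization", pathwise)]:
    print(name, statistics.fmean(g), statistics.variance(g))
```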
Training with backpropagation!
Thanks to the reparametrization trick, we can simultaneously train both the generative model $p_\theta(x \mid z)$ and the inference model $q_\phi(z \mid x)$ by optimizing the variational bound using gradient backpropagation.
Objective function: $\mathcal{L}(\theta, \phi, x) = -D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z)) + \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$
[Figure: forward propagation x → q_φ(z|x) → z → p_θ(x|z) → x̂; backward propagation in the reverse direction.]
Auto-Encoding Variational Bayes: online algorithm
repeat
  compute gradients by backprop (Torch7 / Theano) and update the parameters, e.g. with Adagrad
until convergence
Scales to very large datasets!
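A minimal sketch of the online loop on an assumed toy problem: generative model p(z) = N(0, 1), p_θ(x|z) = N(x; w·z, 1) with θ = w, and inference model q_φ(z|x) = N(a·x, s²) with φ = (a, ρ), s = exp(ρ). Each step draws a random minibatch, forms the L = 1 SGVB estimate with a reparameterised z, and takes a gradient ascent step on θ and φ jointly. Gradients are derived by hand instead of by backprop, plain SGD replaces Adagrad, and a fixed step count replaces the convergence test; all of these are simplifications. Because this q family contains the exact posterior N(w·x/(1 + w²), 1/(1 + w²)), training should recover |w| ≈ 2 for data generated with w = 2.

```python
import math
import random

def train_toy_vae(data, steps=3000, batch_size=100, lr=0.02, seed=0):
    """Stochastic gradient ascent on the SGVB bound (L = 1) for a toy linear-Gaussian VAE."""
    rng = random.Random(seed)
    w, a, rho = 0.5, 0.0, 0.0  # theta = w; phi = (a, rho)
    for _ in range(steps):
        g_w = g_a = g_rho = 0.0
        for x in rng.sample(data, batch_size):  # random minibatch
            s = math.exp(rho)
            eps = rng.gauss(0.0, 1.0)
            z = a * x + s * eps                  # reparameterised sample from q(z|x)
            r = x - w * z                        # residual in log p(x|z)
            # Gradients of log p(x|z) - KL(q || p(z)), KL = 0.5*(a^2 x^2 + s^2 - 2*rho - 1).
            g_w += r * z
            g_a += r * w * x - a * x * x
            g_rho += r * w * s * eps - (s * s - 1.0)
        w += lr * g_w / batch_size               # ascent step on the minibatch-averaged gradient
        a += lr * g_a / batch_size
        rho += lr * g_rho / batch_size
    return w, a, math.exp(rho)

data_rng = random.Random(1)
w_true = 2.0
# Training set sampled from the generative model itself: x = w_true * z + noise.
data = [w_true * data_rng.gauss(0, 1) + data_rng.gauss(0, 1) for _ in range(10_000)]
w, a, s = train_toy_vae(data)
print(f"w={w:.3f} (true {w_true}), a={a:.3f} (posterior {w / (1 + w * w):.3f}), "
      f"s^2={s * s:.3f} (posterior {1 / (1 + w * w):.3f})")
```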
Model used in experiments
[Figure: objective annotated with its regularization terms and the (noisy) negative reconstruction error.]
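As an illustration of the "(noisy) negative reconstruction error" term, assume a Bernoulli decoder, as is commonly used for binarised MNIST: then log p(x|z) is the negative binary cross-entropy between the data x and the decoder's output means f(z), evaluated at a sampled z (hence noisy). The inputs below are made up, and the clipping constant that keeps log() finite is our choice.

```python
import math

def bernoulli_log_likelihood(x, probs, eps=1e-7):
    """log p(x|z) = sum_i x_i log p_i + (1 - x_i) log(1 - p_i), with p = f(z)."""
    total = 0.0
    for xi, p in zip(x, probs):
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        total += xi * math.log(p) + (1.0 - xi) * math.log(1.0 - p)
    return total

x = [1, 0, 1, 1]
good = bernoulli_log_likelihood(x, [0.9, 0.1, 0.8, 0.95])
bad = bernoulli_log_likelihood(x, [0.2, 0.7, 0.3, 0.4])
print(good, bad)  # the closer reconstruction has the higher log-likelihood
```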
Special case with Gaussian prior and posterior
Suppose $p(z) = \mathcal{N}(z; 0, I)$.
Suppose $q_\phi(z \mid x) = \mathcal{N}(z; \mu_\phi(x), \sigma_\phi^2(x))$.
Variational bound:
$\mathcal{L} = \ln p_\theta(x) - D_{KL}(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x))$  (1)
$\phantom{\mathcal{L}} = \mathbb{E}_{q_\phi(z \mid x)}[\ln p_\theta(x \mid z)] - D_{KL}(q_\phi(z \mid x) \,\|\, p(z))$  (2)
Closed-form computation of the KL divergence:
$-D_{KL}(q(z \mid x) \,\|\, p(z)) = \frac{D}{2} + \frac{1}{2} \sum_{d=1}^{D} \left( \ln \sigma_d^2(x) - \mu_d^2(x) - \sigma_d^2(x) \right)$
Deterministic regularization, stochastic data term.
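The closed-form KL can be sanity-checked against a Monte Carlo estimate of E_q[ln q(z|x) − ln p(z)]. The means and standard deviations below are arbitrary values standing in for encoder outputs μ(x), σ(x). The closed form is what makes the regularization term deterministic; only the data term still needs sampling.

```python
import math
import random

def kl_diag_gaussian_std_normal(mu, sigma):
    """Closed form: D_KL(N(mu, diag(sigma^2)) || N(0, I)) = -1/2 * sum(1 + ln sigma^2 - mu^2 - sigma^2)."""
    return -0.5 * sum(1 + math.log(s * s) - m * m - s * s for m, s in zip(mu, sigma))

def kl_monte_carlo(mu, sigma, n_samples, rng):
    """Estimate E_q[ln q(z) - ln p(z)] with reparameterised samples z = mu + sigma * eps."""
    total = 0.0
    for _ in range(n_samples):
        for m, s in zip(mu, sigma):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps
            log_q = -0.5 * math.log(2 * math.pi * s * s) - 0.5 * eps * eps
            log_p = -0.5 * math.log(2 * math.pi) - 0.5 * z * z
            total += log_q - log_p
    return total / n_samples

mu, sigma = [0.5, -1.0, 0.0], [0.8, 0.3, 1.0]
exact = kl_diag_gaussian_std_normal(mu, sigma)
estimate = kl_monte_carlo(mu, sigma, 20_000, random.Random(0))
print(exact, estimate)
```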
Results: Marginal likelihood lower bound
Results: Marginal log-likelihood
Monte Carlo EM does not scale well to large datasets.
Robustness to high-dimensional latent space
Effect of KL term: component collapse
[Figure from Laurent Dinh & Vincent Dumoulin]
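Component collapse shows up in the per-dimension KL term: a latent unit whose approximate posterior equals the prior N(0, 1) for every input contributes zero KL and carries no information about x. A minimal diagnostic sketch (our own illustration, not from the slides) averages the per-dimension KL over a batch of encoder outputs and flags units below a threshold; the threshold value and the encoder outputs are hypothetical.

```python
import math

def per_dim_kl(mu_batch, sigma_batch):
    """Batch average of D_KL(N(mu_d, sigma_d^2) || N(0, 1)) for each latent dimension d."""
    n, dims = len(mu_batch), len(mu_batch[0])
    totals = [0.0] * dims
    for mu, sigma in zip(mu_batch, sigma_batch):
        for d in range(dims):
            m, s = mu[d], sigma[d]
            totals[d] += 0.5 * (m * m + s * s - math.log(s * s) - 1.0)
    return [t / n for t in totals]

def collapsed_units(mu_batch, sigma_batch, threshold=0.01):
    """Indices of latent dimensions whose average KL is below a (hypothetical) threshold."""
    return [d for d, kl in enumerate(per_dim_kl(mu_batch, sigma_batch)) if kl < threshold]

# Made-up encoder outputs for 3 inputs and 3 latent dims; dim 1 stays at the prior.
mu_batch = [[1.2, 0.0, -0.4], [-0.8, 0.01, 0.9], [0.3, -0.02, 0.1]]
sigma_batch = [[0.2, 1.0, 0.5], [0.3, 0.99, 0.4], [0.25, 1.0, 0.6]]
print(collapsed_units(mu_batch, sigma_batch))
```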
Component collapse & depth
Deep model: some component collapse.
Deeper model: more component collapse.
[Figures from Laurent Dinh & Vincent Dumoulin]
Samples from MNIST (simple ancestral sampling)
2D Latent space: Frey Face
[Figure: learned manifold over latent coordinates z1, z2]
2D Latent space: MNIST
[Figure: learned manifold over latent coordinates z1, z2]
Labeled Faces in the Wild (random samples from generative model)
Conditional generation using M2, central pixels image
Conditional generation: central pixels image
Semi-supervised Learning with Deep Generative Models
Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling
They study two basic approaches:
M1: Standard unsupervised feature learning ("self-taught learning")
- Train features z on unlabeled data, then train a classifier to map from z to label y.
- Generative model (recall that x = data, z = latent features): $p(z) = \mathcal{N}(z \mid 0, I)$; $p_\theta(x \mid z) = f(x; z, \theta)$
M2: Generative semi-supervised model
- $p(y) = \mathrm{Cat}(y \mid \pi)$; $p(z) = \mathcal{N}(z \mid 0, I)$; $p_\theta(x \mid y, z) = f(x; y, z, \theta)$
- $\mathrm{Cat}(y \mid \pi)$ is the multinomial distribution; the class labels y are treated as latent variables when no label is available.
[Figure: graphical models for M1 (z → x) and M2 (y, z → x).]
IFT6266: Representation (Deep) Learning, Aaron Courville
Semi-supervised Learning with Deep Generative Models
M1+M2: Combination semi-supervised model
- Learn unsupervised features z1 on unlabeled data with M1, then train the generative semi-supervised model M2 on z1 instead of on the raw data, so the classifier maps from z1 to label y. The result is a deep generative model:
$p_\theta(x, y, z_1, z_2) = p(y)\,p(z_2)\,p_\theta(z_1 \mid y, z_2)\,p_\theta(x \mid z_1)$,
with p(y) and p(z2) as for y and z above.
[Figure: graphical model y, z2 → z1 → x.]
Semi-supervised Learning with Deep Generative Models
Approximate posterior (encoder model)
- Following the VAE strategy, we parametrize the approximate posterior with a high-capacity model, like an MLP or some other deep model (convnet, RNN, etc.).
M1: $q_\phi(z \mid x) = \mathcal{N}(z \mid \mu_\phi(x), \mathrm{diag}(\sigma_\phi^2(x)))$
M2: $q_\phi(z \mid y, x) = \mathcal{N}(z \mid \mu_\phi(y, x), \mathrm{diag}(\sigma_\phi^2(x)))$; $q_\phi(y \mid x) = \mathrm{Cat}(y \mid \pi_\phi(x))$
- $\mu_\phi$, $\sigma_\phi^2$ and $\pi_\phi$ are parameterized by deep MLPs that can share parameters.
[Figure: inference networks for M1 (x → z) and M2 (x → y, z).]
Semi-supervised Learning with Deep Generative Models
M2: The lower bound for the generative semi-supervised model.
- Objective with labeled data:
$\log p_\theta(x, y) \ge \mathbb{E}_{q_\phi(z \mid x, y)}[\log p_\theta(x \mid y, z) + \log p_\theta(y) + \log p(z) - \log q_\phi(z \mid x, y)] = -\mathcal{L}(x, y)$
- Objective without labels (posterior inference also covers the unobserved label y):
$\log p_\theta(x) \ge \mathbb{E}_{q_\phi(y, z \mid x)}[\log p_\theta(x \mid y, z) + \log p_\theta(y) + \log p(z) - \log q_\phi(y, z \mid x)]$
$= \sum_y q_\phi(y \mid x)\,(-\mathcal{L}(x, y)) + \mathcal{H}(q_\phi(y \mid x)) = -\mathcal{U}(x)$
- Semi-supervised objective:
$\mathcal{J} = \sum_{(x, y) \sim \tilde{p}_l} \mathcal{L}(x, y) + \sum_{x \sim \tilde{p}_u} \mathcal{U}(x)$
- Actually, for classification, they use
$\mathcal{J}^\alpha = \mathcal{J} + \alpha \cdot \mathbb{E}_{\tilde{p}_l(x, y)}[-\log q_\phi(y \mid x)]$
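A small sketch of how these terms combine, with made-up numbers: L(x, y) is taken as the negated labelled bound (a quantity to minimise), U(x) = Σ_y q(y|x)·L(x, y) − H(q(y|x)) marginalises over the unknown label, and J^α adds α times the classification loss averaged over labelled pairs. The function names and α value are hypothetical; in practice the per-class L(x, y) values would come from the encoder and decoder networks.

```python
import math

def unlabeled_loss(q_y, losses_per_class):
    """U(x) = sum_y q(y|x) * L(x, y) - H(q(y|x)), with L(x, y) the negated labelled bound."""
    expected = sum(q * l for q, l in zip(q_y, losses_per_class))
    entropy = -sum(q * math.log(q) for q in q_y if q > 0.0)
    return expected - entropy

def objective_alpha(labeled_losses, unlabeled_losses, labeled_log_q_y, alpha):
    """J^alpha = sum of L(x, y) + sum of U(x) + alpha * mean of -log q(y|x) over labelled pairs."""
    j = sum(labeled_losses) + sum(unlabeled_losses)
    classification = -sum(labeled_log_q_y) / len(labeled_log_q_y)
    return j + alpha * classification

q_y = [0.7, 0.2, 0.1]           # hypothetical classifier output q(y|x)
losses = [100.0, 105.0, 110.0]  # hypothetical L(x, y) for each candidate label
u = unlabeled_loss(q_y, losses)
total = objective_alpha([98.0, 101.5], [u], [math.log(0.9), math.log(0.6)], alpha=0.1)
print(u, total)
```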
Semi-supervised MNIST classification results
Combination model M1+M2 shows dramatic improvement:
Table 1: Benchmark results of semi-supervised classification on MNIST with few labels (test error, %; N = number of labeled examples).
N | NN | CNN | TSVM | CAE | MTC | AtlasRBF | M1+TSVM | M2 | M1+M2
100 | 25.81 | 22.98 | 16.81 | 13.47 | 12.03 | 8.10 (±0.95) | 11.82 (±0.25) | 11.97 (±1.71) | 3.33 (±0.14)
600 | 11.44 | 7.68 | 6.16 | 6.3 | 5.13 | n/a | 5.72 (±0.049) | 4.94 (±0.13) | 2.59 (±0.05)
1000 | 10.7 | 6.45 | 5.38 | 4.77 | 3.64 | 3.68 (±0.12) | 4.24 (±0.07) | 3.60 (±0.56) | 2.40 (±0.02)
3000 | 6.04 | 3.35 | 3.45 | 3.22 | 2.57 | n/a | 3.49 (±0.04) | 3.92 (±0.63) | 2.18 (±0.04)
Full MNIST test error: 0.96% (for comparison, current SOTA: 0.78%).
Conditional generation using M2