Bayesian Deep Learning
Maurizio Filippone
EURECOM, Sophia Antipolis, France
January 10th, 2019
Outline
1 Motivation
2 Probabilistic Deep Nets
• Scalable Inference
• Connections with (Deep) Gaussian Processes
3 Some Results
4 Conclusions
Motivation
Quantification of Uncertainty with Expensive Models
• Climate modeling
Kennedy and O’Hagan, J-RSS-B, 2001
Quantification of Uncertainty with No Models
• Classification and progression of neurodegenerative diseases
A Unified Framework
A model might be expensive to simulate/inaccurate
• Emulate model/discrepancy using a surrogate
A model might not even be available
• Replace it with a flexible statistical model
Probabilistic Deep Models for Accurate Modeling and Quantification of Uncertainty
Probabilistic Deep Nets
Learning from Data – Function Estimation
• Take these two examples
[Figure: two example datasets plotted as y versus x for x between −3 and 3; in the first y ranges from −3 to 3, in the second y ranges from 0 to 1]
• We are interested in estimating a function f(x) from data
• Most problems in Machine Learning can be cast this way!
Deep Neural Networks
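• Implement a composition of parametric functions
$f(x) = f^{(L)}\left(f^{(L-1)}\left(\cdots f^{(1)}(x) \cdots\right)\right)$, with $f^{(l)}(h) = g\left(h^\top W^{(l)}\right)$
[Diagram: x → f(1) → f(2) → f(3) → f(4) → y, with weights W(1), W(2), W(3), W(4)]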
Back-propagation – Probabilistic Interpretation
• Inputs: $X = \{x_1, \ldots, x_N\}$
• Labels: $Y = \{y_1, \ldots, y_N\}$
• Weights: $W = \{W^{(1)}, \ldots, W^{(L)}\}$
[Figure: a quadratic Loss and the corresponding likelihood, $p(Y|X,W) \propto \exp(-\mathrm{Loss})$]
• Back-propagation minimizes a loss function
• . . . which is equivalent to maximizing the likelihood $p(Y|X,W)$
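A minimal numerical check of this equivalence, assuming a Gaussian likelihood with unit noise variance and a hypothetical one-weight linear model (neither is specified on the slide): the quadratic loss and the log-likelihood differ only by a constant, so they share the same optimum.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data and a one-weight linear "network" f(x) = x * w.
x = rng.normal(size=20)
y = 1.5 * x + rng.normal(scale=0.3, size=20)

def quadratic_loss(w):
    # Loss = 1/2 * sum_k (y_k - f(x_k))^2
    return 0.5 * np.sum((y - x * w) ** 2)

def gaussian_log_likelihood(w):
    # log p(Y|X,w) for y_k ~ N(f(x_k), 1), normalizing constant included
    n = y.size
    return -0.5 * n * np.log(2.0 * np.pi) - 0.5 * np.sum((y - x * w) ** 2)

grid = np.linspace(-3.0, 3.0, 601)
losses = np.array([quadratic_loss(w) for w in grid])
loglik = np.array([gaussian_log_likelihood(w) for w in grid])

# Loss + log-likelihood is constant in w, so argmin(loss) = argmax(likelihood).
w_min_loss = grid[np.argmin(losses)]
w_max_lik = grid[np.argmax(loglik)]
```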
Bayesian Inference
• Inputs: $X = \{x_1, \ldots, x_N\}$
• Labels: $Y = \{y_1, \ldots, y_N\}$
• Weights: $W = \{W^{(1)}, \ldots, W^{(L)}\}$
[Figure: prior $p(W)$ and posterior $p(W|Y,X)$]
$p(W|Y,X) = \dfrac{p(Y|X,W)\,p(W)}{\int p(Y|X,W)\,p(W)\,dW}$
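For a single weight, Bayes' rule can be evaluated directly on a grid. The sketch below assumes a hypothetical linear model y = x w + noise with known noise variance and a standard normal prior (all choices made for illustration), and compares the grid posterior with the closed-form Gaussian posterior available in this conjugate case.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy data: y = x * w + noise, with an assumed known noise variance.
x = rng.normal(size=15)
noise_var = 0.25
y = 0.8 * x + rng.normal(scale=np.sqrt(noise_var), size=15)

# Grid over the single weight w.
w_grid = np.linspace(-3.0, 3.0, 2001)
dw = w_grid[1] - w_grid[0]

# log p(Y|X,w) and log p(w), with w ~ N(0, 1); additive constants are dropped
# because they cancel in the normalization.
resid = y[None, :] - w_grid[:, None] * x[None, :]
log_lik = -0.5 * np.sum(resid ** 2, axis=1) / noise_var
log_prior = -0.5 * w_grid ** 2
log_joint = log_lik + log_prior

# Numerator of Bayes' rule (rescaled for numerical stability) divided by a
# Riemann-sum approximation of the integral in the denominator.
unnorm = np.exp(log_joint - log_joint.max())
posterior = unnorm / (np.sum(unnorm) * dw)
grid_mean = np.sum(w_grid * posterior) * dw

# Closed-form posterior mean for this conjugate Gaussian model.
post_var = 1.0 / (1.0 + np.sum(x ** 2) / noise_var)
post_mean = post_var * np.sum(x * y) / noise_var
```

With many weights the integral in the denominator can no longer be computed on a grid, which is what motivates the approximations that follow.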
Bayesian Deep Neural Networks
• Regression example
[Figure: regression example, two panels showing y versus x with both axes ranging from −3 to 3]
Bayesian Deep Neural Networks
• Classification example
[Figure: classification example, two panels showing y (from 0 to 1) versus x (from −3 to 3)]
Stochastic Variational Inference
• Bayesian inference is intractable due to this integral:
$\log[p(Y|X)] = \log \int p(Y|X,W)\,p(W)\,dW$
• Lower bound for $\log[p(Y|X)]$:
$\mathbb{E}_{q(W)}(\log[p(Y|X,W)]) - \mathrm{KL}[q(W)\,\|\,p(W)]$, where $q(W)$ approximates $p(W|Y,X)$.
• Kullback-Leibler divergence KL – “distance” between q and p
Optimize the lower bound wrt the parameters of $q(W)$
Graves, NIPS, 2011
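The lower bound follows from Jensen's inequality applied to the log of the integral; the standard derivation, written with the same symbols as the slide, is:

```latex
\begin{aligned}
\log[p(Y|X)]
  &= \log \int q(W)\,\frac{p(Y|X,W)\,p(W)}{q(W)}\,dW \\
  &\geq \int q(W)\,\log\!\left[\frac{p(Y|X,W)\,p(W)}{q(W)}\right] dW \\
  &= \mathbb{E}_{q(W)}\left(\log[p(Y|X,W)]\right)
     - \mathrm{KL}\left[q(W)\,\|\,p(W)\right]
\end{aligned}
```

The gap between the two sides is $\mathrm{KL}[q(W)\,\|\,p(W|Y,X)]$, so maximizing the bound over the parameters of $q(W)$ moves $q(W)$ towards the posterior.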
Stochastic Variational Inference
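• Assume that the likelihood factorizes
$p(Y|X,W) = \prod_k p(y_k|x_k,W)$
• Doubly stochastic unbiased estimate of the expectation term
• Mini-batch
$\mathbb{E}_{q(W)}(\log[p(Y|X,W)]) \approx \dfrac{n}{m} \sum_{k \in I_m} \mathbb{E}_{q(W)}(\log[p(y_k|x_k,W)])$
• Monte Carlo
$\mathbb{E}_{q(W)}(\log[p(y_k|x_k,W)]) \approx \dfrac{1}{N_{\mathrm{MC}}} \sum_{r=1}^{N_{\mathrm{MC}}} \log[p(y_k|x_k,\tilde{W}_r)]$, with $\tilde{W}_r \sim q(W)$.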
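The two sources of randomness can be combined as in the sketch below. It assumes a hypothetical one-weight linear-Gaussian model and a Gaussian q(W), chosen because the expectation is then available in closed form, so the unbiasedness of the estimate can be checked by averaging many draws.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy setup: n data points, mini-batches of size m, N_MC samples.
n, m, n_mc = 200, 20, 5
noise_var = 0.25
x = rng.normal(size=n)
y = 0.7 * x + rng.normal(scale=np.sqrt(noise_var), size=n)
mu, sigma = 0.5, 0.3  # parameters of q(W) = N(mu, sigma^2), chosen arbitrarily

def log_lik(y_k, x_k, w):
    # log p(y_k | x_k, w) for y_k ~ N(x_k * w, noise_var)
    return -0.5 * np.log(2 * np.pi * noise_var) - 0.5 * (y_k - x_k * w) ** 2 / noise_var

def doubly_stochastic_estimate(rng):
    idx = rng.choice(n, size=m, replace=False)     # mini-batch I_m
    w_tilde = mu + sigma * rng.normal(size=n_mc)   # W~_r ~ q(W)
    ll = log_lik(y[idx][None, :], x[idx][None, :], w_tilde[:, None])  # (n_mc, m)
    # Monte Carlo average over samples, then rescaled sum over the mini-batch.
    return (n / m) * ll.mean(axis=0).sum()

# Closed-form expectation, available only because the toy model is linear-Gaussian.
exact = np.sum(-0.5 * np.log(2 * np.pi * noise_var)
               - 0.5 * ((y - x * mu) ** 2 + (x * sigma) ** 2) / noise_var)

estimates = np.array([doubly_stochastic_estimate(rng) for _ in range(4000)])
```

Each individual estimate is noisy, but `estimates.mean()` should be close to `exact`.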
Stochastic Variational Inference
• Assume a factorized Gaussian approximate posterior:
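$q(W) = \prod_{ijl} q\left(W^{(l)}_{ij}\right) = \prod_{ijl} \mathcal{N}\left(\mu^{(l)}_{ij}, (\sigma^2)^{(l)}_{ij}\right)$ (1)
• Reparameterization trick
$(\tilde{W}^{(l)}_r)_{ij} = \sigma^{(l)}_{ij}\,\varepsilon^{(l)}_{rij} + \mu^{(l)}_{ij}$, with $\varepsilon^{(l)}_{rij} \sim \mathcal{N}(0,1)$
• Optimization wrt $\mu^{(l)}_{ij}, (\sigma^2)^{(l)}_{ij}$ with automatic differentiation
Kingma and Welling, ICLR, 2014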
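A small sketch of why the reparameterization helps, using a single weight and the hypothetical test function f(w) = w² (an assumption made so that the exact gradients, 2μ and 2σ, are known). In practice the derivatives are obtained by automatic differentiation; here they are written out by hand with the chain rule.

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma = 1.2, 0.5            # parameters of q(w) = N(mu, sigma^2)
eps = rng.normal(size=200_000)  # eps ~ N(0, 1), independent of mu and sigma
w_tilde = sigma * eps + mu      # reparameterized samples from q(w)

# f(w) = w^2, so df/dw = 2w. Because w_tilde is a deterministic function of
# (mu, sigma) given eps, the chain rule gives pathwise gradient estimates:
#   d/dmu    E_q[f] ~ mean(df/dw * dw/dmu)    with dw/dmu = 1
#   d/dsigma E_q[f] ~ mean(df/dw * dw/dsigma) with dw/dsigma = eps
grad_mu = np.mean(2 * w_tilde)
grad_sigma = np.mean(2 * w_tilde * eps)

# Exact values for comparison: E_q[w^2] = mu^2 + sigma^2.
exact_grad_mu, exact_grad_sigma = 2 * mu, 2 * sigma
```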
Stochastic Gradient Optimization
$\mathbb{E}\left\{\hat{\nabla}_{\mathrm{par}_q} \mathrm{LowerBound}\right\} = \nabla_{\mathrm{par}_q} \mathrm{LowerBound}$
[Figure: two scatter panels, both axes ranging from −4 to 4]
Stochastic Variational Inference - Simple Illustration
$\mathrm{par}_q' = \mathrm{par}_q + \dfrac{\alpha_t}{2}\,\hat{\nabla}_{\mathrm{par}_q}(\mathrm{LowerBound})$, with $\alpha_t \to 0$
[Figure: a parameter-space panel (axes from −4 to 4), a trace plotted against Iteration (0 to 100), and panels showing p(par | data)]
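A toy run of this update rule, under stated assumptions: a hypothetical conjugate model (prior w ~ N(0, 1), likelihood y_k ~ N(w, 1)) so that the exact posterior is known, a Gaussian q(w) = N(m, s²) with s = exp(rho), one Monte Carlo sample per step, and a hand-picked step-size schedule. None of these choices come from the slides.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical conjugate model: w ~ N(0, 1), y_k ~ N(w, 1).
n, batch = 50, 10
y = 1.0 + rng.normal(size=n)
post_var = 1.0 / (n + 1.0)        # exact posterior variance
post_mean = y.sum() * post_var    # exact posterior mean

# Variational parameters of q(w) = N(m, s^2), with s = exp(rho) to keep s > 0.
m, rho = 0.0, np.log(0.3)

for t in range(5000):
    alpha_t = 0.02 / (1.0 + 0.01 * t)              # decaying step size, alpha_t -> 0
    idx = rng.choice(n, size=batch, replace=False)  # mini-batch
    eps = rng.normal()
    s = np.exp(rho)
    w = m + s * eps                                 # reparameterized sample
    # Derivative wrt w of the mini-batch estimate of the log-likelihood.
    dll_dw = (n / batch) * np.sum(y[idx] - w)
    # Stochastic gradient of the lower bound: pathwise gradient of the
    # expected log-likelihood minus the gradient of KL[q || N(0, 1)],
    # where dKL/dm = m and dKL/drho = s^2 - 1.
    grad_m = dll_dw - m
    grad_rho = dll_dw * eps * s - (s ** 2 - 1.0)
    m += 0.5 * alpha_t * grad_m
    rho += 0.5 * alpha_t * grad_rho

s = np.exp(rho)
```

After the loop, `m` and `s` should be close to `post_mean` and `np.sqrt(post_var)`; the remaining discrepancy is stochastic-gradient noise that shrinks as the step size decays.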
Form of Approximating Distribution
Approximating distribution $q(W)$ can have the following forms:
• Fully factorized
• Full covariance, $\Sigma = LL^\top$
• Normalizing Flows, Real NVPs, Stein VI - change of measure determined by det(Jacobian)
[Figure: one density panel per form, axes from −5 to 5]
Rezende et al., ICML, 2015 – Dinh et al., ICLR, 2017 – Liu and Wang, NIPS, 2016
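For the full-covariance case, samples can be drawn by pushing standard normal noise through the factor L, which keeps the reparameterization trick applicable. A sketch with an arbitrary two-dimensional example (the values of the mean and of L are made up):

```python
import numpy as np

rng = np.random.default_rng(0)

mu = np.array([1.0, -2.0])
# Lower-triangular factor L; the covariance is Sigma = L L^T.
L = np.array([[1.0, 0.0],
              [0.8, 0.5]])
Sigma = L @ L.T

# w = mu + L eps with eps ~ N(0, I) has mean mu and covariance L L^T.
eps = rng.normal(size=(100_000, 2))
samples = mu + eps @ L.T

emp_mean = samples.mean(axis=0)
emp_cov = np.cov(samples, rowvar=False)
```

The fully factorized form is the special case in which L is diagonal.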
Initialization of SVI matters
• Initialization can be an issue
• We proposed a novel way to initialize SVI well
[Figure: two panels with x from −10 to 10 and y from −2 to 2, titled AFTER POOR INITIALIZATION and AFTER OUR INITIALIZATION]
Rossi, Michiardi and Filippone, arXiv, 2018
Is There an Easier Way?
• Dropout is Variational Inference with Bernoulli-like q(W)
• At training time, apply dropout
[Figure: networks with different units dropped at Iteration 1, Iteration 2, Iteration 3, . . .]
• At test time, “sample” networks with different dropout masks
[Figure: regression example, y versus x with both axes ranging from −3 to 3]
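Test-time sampling with dropout can be sketched as follows. The network is a hypothetical one-hidden-layer model with random, untrained weights, and the keep probability is made up; the point is only the mechanics of drawing a fresh mask per forward pass and summarizing the resulting predictions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical untrained network: 1 input, 50 tanh hidden units, 1 output.
W1 = rng.normal(size=(1, 50))
W2 = rng.normal(size=(50, 1)) / np.sqrt(50)
keep_prob = 0.9  # assumed probability of keeping a hidden unit

def forward_with_dropout(x, rng):
    h = np.tanh(x @ W1)
    # Fresh Bernoulli mask on every call: each call is one "sampled" network.
    mask = rng.random(h.shape) < keep_prob
    h = h * mask / keep_prob  # inverted-dropout rescaling
    return h @ W2

x_test = np.linspace(-3.0, 3.0, 7)[:, None]
draws = np.stack([forward_with_dropout(x_test, rng) for _ in range(200)])

pred_mean = draws.mean(axis=0)  # predictive mean per test input
pred_std = draws.std(axis=0)    # spread across dropout masks
```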
Gaussian Processes as Infinitely-Wide Shallow Neural Nets
• Take $W^{(i)} \sim \mathcal{N}(0, \alpha_i I)$
• Central Limit Theorem implies that F is Gaussian
[Diagram: X → Φ → F, with weights W(0) between X and Φ, and W(1) between Φ and F]
• F has zero-mean
• $\mathrm{cov}(F) = \mathbb{E}_{p(W^{(0)},W^{(1)})}\left[\Phi(XW^{(0)})\,W^{(1)} W^{(1)\top}\,\Phi(XW^{(0)})^\top\right] = \alpha_1\,\mathbb{E}_{p(W^{(0)})}\left[\Phi(XW^{(0)})\,\Phi(XW^{(0)})^\top\right]$
• Some choices of Φ lead to analytic expressions of known kernels (RBF, Matérn, arc-cosine, Brownian motion, . . .)
Neal, LNS, 1996 – Rasmussen and Williams, 2006
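The step that integrates out W(1) can be checked numerically. The sketch below fixes one draw of W(0), uses tanh for Φ (an assumed choice) and a single output, and compares the empirical covariance of F over many draws of W(1) with α1 Φ Φᵀ.

```python
import numpy as np

rng = np.random.default_rng(0)

alpha0, alpha1 = 1.0, 0.5   # assumed prior variances
n, d, h = 4, 2, 3           # inputs, input dimension, hidden units
X = rng.normal(size=(n, d))

# One fixed draw of the first-layer weights and the resulting features.
W0 = rng.normal(scale=np.sqrt(alpha0), size=(d, h))
Phi = np.tanh(X @ W0)

# Many independent draws of the output weights W1 ~ N(0, alpha1 I);
# each column of F is one draw of the n function values.
W1 = rng.normal(scale=np.sqrt(alpha1), size=(h, 200_000))
F = Phi @ W1

emp_mean = F.mean(axis=1)
emp_cov = F @ F.T / F.shape[1]
target = alpha1 * Phi @ Phi.T  # E[W1 W1^T] = alpha1 I
```

Averaging `target` over draws of `W0` as well gives the remaining expectation in the covariance formula.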
Random Feature Expansions for DGPs - Bochner’s theorem
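• Continuous shift-invariant covariance function
$k(x_i - x_j|\theta) = \sigma^2 \int p(\omega|\theta) \exp\left(\iota (x_i - x_j)^\top \omega\right) d\omega$
• Monte Carlo estimate
$k(x_i - x_j|\theta) \approx \dfrac{\sigma^2}{N_{\mathrm{RF}}} \sum_{r=1}^{N_{\mathrm{RF}}} z(x_i|\tilde{\omega}_r)^\top z(x_j|\tilde{\omega}_r)$, with $\tilde{\omega}_r \sim p(\omega|\theta)$
$z(x|\omega) = [\cos(x^\top \omega), \sin(x^\top \omega)]^\top$
Rahimi and Recht, NIPS, 2008 - Lázaro-Gredilla et al., JMLR, 2010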
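A sketch of the Monte Carlo estimate for one concrete kernel. It assumes the RBF kernel k(x_i − x_j) = σ² exp(−‖x_i − x_j‖² / (2ℓ²)), whose spectral density p(ω|θ) is Gaussian with covariance I/ℓ²; the kernel choice, the lengthscale ℓ and all sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

d, n_rf = 3, 5000
sigma2, lengthscale = 1.5, 1.0   # assumed kernel parameters theta
X = rng.normal(size=(10, d))

# Spectral frequencies: for the RBF kernel, omega ~ N(0, I / lengthscale^2).
Omega = rng.normal(size=(d, n_rf)) / lengthscale

def random_features(X):
    proj = X @ Omega
    # Stack cos and sin features and fold in the sigma^2 / N_RF scaling.
    return np.sqrt(sigma2 / n_rf) * np.hstack([np.cos(proj), np.sin(proj)])

Z = random_features(X)
K_approx = Z @ Z.T

# Exact RBF kernel matrix for comparison.
sq_dist = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
K_exact = sigma2 * np.exp(-0.5 * sq_dist / lengthscale ** 2)

max_abs_err = np.max(np.abs(K_approx - K_exact))
```

The error shrinks at the usual Monte Carlo rate as `n_rf` grows.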
Random Feature Expansions for DGPs
• Define
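$\Phi^{(l)} = \sqrt{\dfrac{\sigma^2}{N_{\mathrm{RF}}^{(l)}}} \left[\cos\left(F^{(l)}\Omega^{(l)}\right), \sin\left(F^{(l)}\Omega^{(l)}\right)\right]$
and
$F^{(l+1)} = \Phi^{(l)} W^{(l)}$
• We are stacking Bayesian linear models with $p\left(W^{(l)}_{\cdot i}\right) = \mathcal{N}(0, I)$
• Expansion of arc-cosine kernel yields ReLU activations!
Cutajar, Bonilla, Michiardi, Filippone, ICML, 2017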
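The two definitions translate into a forward pass that looks like an ordinary feed-forward network with trigonometric activations. The sketch below draws one sample from the prior of a two-layer model; the layer sizes, the standard normal frequencies Ω and the helper name `rf_layer` are assumptions made for illustration, not the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def rf_layer(F, Omega, W, sigma2=1.0):
    """One layer: random features Phi of the input F, then a linear map W."""
    n_rf = Omega.shape[1]
    proj = F @ Omega
    Phi = np.sqrt(sigma2 / n_rf) * np.hstack([np.cos(proj), np.sin(proj)])
    return Phi @ W

n, n_rf = 5, 100
dims = [2, 3, 1]  # input, hidden and output dimensions (hypothetical)
X = rng.normal(size=(n, dims[0]))

F = X  # F^(0) is the input
for l in range(len(dims) - 1):
    Omega = rng.normal(size=(dims[l], n_rf))     # spectral frequencies Omega^(l)
    W = rng.normal(size=(2 * n_rf, dims[l + 1])) # columns W^(l)_{.i} ~ N(0, I)
    F = rf_layer(F, Omega, W)
```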
Random Feature Expansions make Deep GPs become DNNs
[Diagram: X → Φ(0) → F(1) → Φ(1) → F(2) → Y, with θ(0), Ω(0), W(0) attached to the first layer and θ(1), Ω(1), W(1) to the second]
Cutajar, Bonilla, Michiardi, Filippone, ICML, 2017
Some Results
Results - Model (Depth) Selection
Airline dataset (n = 5M+, d = 8)
[Figure: Error rate versus log10(sec); MNLL versus log10(sec); Neg. Lower Bound (·10^6) versus number of Layers. Legend: 2 layers, 10 layers, 20 layers, 30 layers, SV-DKL]
Cutajar, Bonilla, Michiardi, Filippone, ICML, 2017
Convolutional Nets
• Convolutional nets are widely used. . .
• . . .but they are known to be overconfident!
[Diagram: Convolution + ReLU → Pooling → Convolution + ReLU → Pooling → Fully connected layers → Output]
Calibration as a Measure of Quantification of Uncertainty
• Reliability diagrams
[Figure: reliability diagram, Fraction positives versus Predicted value, both axes from 0 to 1]
Calibration as a Measure of Quantification of Uncertainty
• Reliability diagrams - Under-confident predictions
[Figure: reliability diagram, Fraction positives versus Predicted value, both axes from 0 to 1]
• We can extract the Expected Calibration Error (ECE) score
• The Brier score is another measure of calibration
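Both scores can be computed from predicted class probabilities and true labels. The sketch below uses common definitions, which are assumptions since the slides do not spell them out: ECE with equal-width bins on the confidence of the predicted class, and the multi-class Brier score as the mean squared distance between the probability vector and the one-hot label. The four predictions are made up.

```python
import numpy as np

def expected_calibration_error(probs, labels, n_bins=10):
    """Bin-weighted average gap between accuracy and confidence."""
    conf = probs.max(axis=1)               # confidence of the predicted class
    correct = (probs.argmax(axis=1) == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - conf[in_bin].mean())
            ece += in_bin.mean() * gap     # weight by fraction of points in bin
    return ece

def brier_score(probs, labels):
    """Mean over points of the squared error summed over classes."""
    onehot = np.eye(probs.shape[1])[labels]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))

# Made-up predictions for four points and two classes.
probs = np.array([[0.95, 0.05],
                  [0.85, 0.15],
                  [0.25, 0.75],
                  [0.65, 0.35]])
labels = np.array([0, 0, 1, 1])

ece = expected_calibration_error(probs, labels)
brier = brier_score(probs, labels)
```

For these four points each prediction falls in its own bin, giving an ECE of 0.275 and a Brier score of 0.255.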
Calibration as a Measure of Quantification of Uncertainty
• Reliability diagrams - Overconfident predictions
[Figure: reliability diagram, Fraction positives versus Predicted value, both axes from 0 to 1]
Reliability diagrams of modern Deep CNNs look like this!
Bayesian treatment of filters fixes it!
Bayesian CNNs are calibrated
• Inferring parameters of convolutional filter recovers calibration
• Example with Monte Carlo Dropout
RELIABILITY DIAGRAM FOR MCD
[Figure: grid of reliability diagrams with predictive output on the horizontal axis (0 to 1); one row per model, one column per panel label 1/8, 1/4, 1/2, 1. Scores printed in the panels:]
Model            Score   1/8      1/4      1/2      1
cifar10-lenet    ECE     0.0254   0.0145   0.0270   0.0495
cifar10-lenet    BRIER   0.3951   0.3300   0.2827   0.2404
cifar10-resnet   ECE     0.0485   0.0140   0.0295   0.0456
cifar10-resnet   BRIER   0.3518   0.2765   0.2126   0.1799
cifar100-resnet  ECE     0.2201   0.1357   0.0492   0.0270
cifar100-resnet  BRIER   0.8334   0.6886   0.5463   0.4615
Performance Evaluation of Bayesian CNNs
• Bayesian CNNs are calibrated and achieve better performance than post-calibrated CNNs
[Figure: log10(ERR), log10(MNLL), log10(ECE) and log10(BRIER) versus proportion of training data (1/8, 1/4, 1/2, 1) for MNIST-LENET, CIFAR10-LENET, CIFAR10-RESNET and CIFAR100-RESNET. Legend: GPDNN, CGP, SVDKL, CAL, MCD, CNN+GP(RF), CNN+GP(SORF)]
Tran et al., AISTATS, 2019
Knowing When the Model Doesn’t Know
• Training on MNIST and test on not-MNIST
[Figure: histograms on a horizontal axis from 0 to 1 (counts up to 8000), comparing Test on MNIST with Test on NOT-MNIST. Panel MCD: Average = 0.026 and Average = 0.383. Panel I-BLM INITIALIZATION: Average = 0.035 and Average = 0.526]
Conclusions
Conclusions
• Inference for Deep Nets is hard
• Scalable stochastic-based approximate inference but...
• ... it is difficult to assess the impact of approximations on quantification of uncertainty
• The connection between Deep Nets and Deep Gaussian processes can have implications on
  • Understanding Deep Learning
  • Deriving sensible priors for Deep Learning
  • Improving inference by borrowing algebraic/computational tricks from the kernel literature
• Cool stuff
  • New hardware
  • Bayesian compression
We are hiring!
We are hiring PhDs, Post-docs and Assistant Professors
Acknowledgments