ジョンとヨーコのイマジン日記

キョウとアンナのラヴラブダイアイリー改め、ジョンとヨーコのイマジン日記です。

コレスキー分解された分散共分散行列を持つ正規分布で事後分布を近似する変分ベイズ(in Julia)

アルゴリズムの導出

変分事後分布に正規分布を仮定して、変分パラメータを勾配法で推定する方法はそれなりによく知られている。例えば須山「ベイズ推論による機械学習入門」を参照されたい。

そこでは変分事後分布はすべて互いに独立であるようにとるが、事前分布が独立でも事後分布は独立にならないので、この独立性の仮定を少し緩めてやったほうが事後分布の近似は良くなると思われる。

ところで、多変量正規分布(相関のある正規分布)に従う乱数のベクトル x がほしいときは、分散共分散行列  \Sigma について  \Sigma = LL^\top となるような下三角行列  L を用いて、 x = \mu + Lz とする方法が広く使われている。 \Sigma = LL^\top なる分解をコレスキー分解と呼ぶ。ここで \mu は平均ベクトル、z は標準正規乱数のベクトルである。

さて、モデルのすべての未知パラメータを w とまとめて置き、その変分事後分布 q(w) を平均 \mu、分散共分散行列 LL^\top の多変量正規分布(Stan でいう multi_normal_cholesky; 22.3 Multivariate Normal Distribution, Cholesky Parameterization | Stan Functions Reference)とする方法を考えてみる。

\displaystyle \log q(w) = - \frac{1}{2}\left(\log \det\left\{L L^\top\right\} + (w-\mu)^\top (L L^\top)^{-1} (w-\mu) \right) + \mathrm{const.}

ここで  \mathrm{const.} は変分パラメータ(知りたいターゲット)の \muL のとり方に依存しない定数項で、微分すれば消える。

多くの場合、コレスキー分解では L の対角成分が正になるようにするが、計算したあとに符号を変えてもいいので、今回は気にしないことにする。

標準正規乱数のベクトル \tilde \varepsilon を使ってサンプル  \tilde{w} = \mu + L \tilde{\varepsilon} を得ると、これは近似事後分布からのサンプルそのものなので、観測された目的変数を  Y、 説明変数を  X として、尤度を  p(Y|X,w) と書くと真の事後分布と変分事後分布のカルバックライブラ距離は、

 \displaystyle \operatorname{KL}(q(w)\|p(w|Y,X)) \approx \log q(\tilde w)- \log p(\tilde w) - \log p(Y|X,\tilde w)+\mathrm{const.}

と書ける。このようなアイデアを再パラメータ化トリック(re-parameterization trick; [1505.05424] Weight Uncertainty in Neural Networks)と呼ぶ。

この   \operatorname{KL}微分して、勾配法で \muL を更新してやる。真の事後分布とのカルバック・ライブラ距離が近い正規分布を求めるのが目標である。

 \log q(\tilde w) \tilde{w} = \mu + L \tilde{\varepsilon} を代入すると

 \displaystyle \log q(w) = - \frac{1}{2}\left(\log \det\left\{L L^\top\right\} + \tilde \varepsilon^\top \tilde \varepsilon \right)+\mathrm{const.} \\
 \displaystyle = - \frac{1}{2}\log \det\left\{L L^\top\right\} +\mathrm{const.}

である。

また、正規分布でなければいけないわけではないが、計算が楽になるようにモデルのすべての未知パラメータ w に対して、平均0、分散 \lambda^{-1} の単変量正規事前分布 p(w) を設定する(この w はパラメータのベクトル w の各要素だが、見やすさのため添字は省略した)と、

\displaystyle \log p(w) =\frac{1}{2} \log \lambda - \lambda \frac{w^2}{2}+\mathrm{const.}

なので、
 \displaystyle \frac{\partial}{\partial \mu}\log q(\tilde w) =0
 \displaystyle \frac{\partial}{\partial L}\log q(\tilde w) = A
 \displaystyle \frac{\partial}{\partial \mu}\log p(\tilde w) =-\lambda \tilde w
 \displaystyle \frac{\partial}{\partial L}\log p(\tilde w) = -\lambda \tilde w \varepsilon^\top

AL の対角成分の逆数からなる対角行列である。理由は行列の要素を愚直に書き下してみればわかる。

まとめると次のアルゴリズムが得られる:

  1. 学習率 \alpha と 事前分布の精度パラメータ  \lambda を設定。
  2. \muL を適当に初期化し以下を繰り返す。
  3. すべてのパラメータに対して、標準正規乱数 \tilde{\varepsilon} を使ってサンプル  \tilde{w} = \mu + L \tilde{\varepsilon} を得る。
  4.  g = -\frac{d}{dw} \log p(Y|X, \tilde w) を計算。
  5.  g_\mu =  g + \lambda\tilde{w} とする。
  6.  g_L = (g + \lambda \tilde{w}) \tilde{\varepsilon}^\top - A とする。
  7.  \mu \leftarrow \mu - \alpha g_\mu で更新。
  8.  L \leftarrow L - \alpha g_L で更新。

Julia のコード

Julia には ForwardDiff など、自動微分のパッケージがあるので対数尤度の微分のところは自分で計算しなくてよい(場合もある)。

今回は乱数で適当に作ったデータでロジスティック回帰とポアソン回帰をやってみる。

サンプルサイズは100とした。

こんなふうだ:

module VItools

using Distributions
using ForwardDiff
using Random
using LinearAlgebra

function inittri(rng,n)
    L = zeros(n,n)
    for i in 1:n
        L[i:n,i] = randn(rng,n-i+1)
    end
    return LowerTriangular(L)
end

function prodLower(a,b)
    n = length(a)
    L = zeros(n,n)
    for i in 1:n
        for j in i:n
        L[j,i] = a[j]*b[i]
        end
    end
    return LowerTriangular(L)
end

function GDChol(f, par0, lambda, lr, maxiter) 
    rng = Random.default_rng()
    len = length(par0)
    mu = randn(rng,len)
    L = inittri(rng,len)
    g(beta) = ForwardDiff.gradient(beta0 -> f(beta0), beta)
    logloss = zeros(maxiter)
    for i in 1:maxiter
    epsilon = randn(rng,len)
    beta = mu + L*epsilon
    fx = f(beta)
    gvec = g(beta)
    g_mu = gvec + lambda*beta
    g_L = prodLower((gvec + lambda*beta), epsilon) - Diagonal(1.0 ./diag(L))
    mu -= lr * g_mu
    L -= lr * g_L
    logloss[i] = fx
    end
    return mu, L, logloss
end

function GDdiag(f, par0, lambda, lr, maxiter)
    rng = Random.default_rng()
    len = length(par0)
    mu = randn(rng,len)
    L = randn(rng,len)
    g(beta) = ForwardDiff.gradient(beta0 -> f(beta0),beta)
    logloss = zeros(maxiter)
    for i in 1:maxiter
    epsilon = randn(rng,len)
    beta = mu + L.*epsilon
    fx = f(beta)
    gvec = g(beta)
    g_mu = gvec + lambda*beta
    g_L = (gvec + lambda*beta).*epsilon .- 1.0 ./ L
    mu -= lr * g_mu
    L -= lr * g_L
    logloss[i] = fx
    end
    return mu, L, logloss
end

end

using Distributions
using ForwardDiff
using Random
using Plots
using Optim
using StatsFuns
using LinearAlgebra

function poisonloss(beta,y,X)
    Xbeta = X*beta
    return -sum(y[i] * Xbeta[i] - exp(Xbeta[i]) for i in eachindex(y))
end

function logisticloss(beta,y,X)
    Xbeta = X*beta
    return sum(y[i] * log1p(exp(-Xbeta[i])) + (1 - y[i]) * log1p(exp(Xbeta[i])) for i in eachindex(y))
end

rng = Random.default_rng()
Random.seed!(123456)
x = randn(rng,100)
X = [ones(100) x]
beta = [2.0, -1.0]
y = rand.(rng, Poisson.(exp.(X*beta)))
f(beta) = poisonloss(beta,y,X)
betaini = rand(rng, 2)
@time out = VItools.GDChol(f, betaini, 1.0, 1.0e-5, 2000)
@time outd = VItools.GDdiag(f, betaini, 1.0, 1.0e-5, 2000)

plot(out[3],label="Cholesky")
plot!(outd[3],label="diagonal")
png("./Desktop/loss1.png")

eps = randn(rng,2,10000)
betasmp = out[1].+out[2]*eps
betasmpd = outd[1].+outd[2].*eps

xs = sort(x)
Xs = [ones(100) xs]
pred = rand.(rng,Poisson.(exp.(Xs*betasmp)))
predd = rand.(rng,Poisson.(exp.(Xs*betasmpd)))

predmean = mean(pred,dims=2)
predmeand = mean(predd,dims=2)

lwr = [quantile(pred[i,:],0.025) for i in 1:100]
upr = [quantile(pred[i,:],0.975) for i in 1:100]
lwrd = [quantile(predd[i,:],0.025) for i in 1:100]
uprd = [quantile(predd[i,:],0.975) for i in 1:100]

scatter(x,y,msw=0,color="gray",legend=false)
plot!(xs,predmean, color="blue")
plot!(xs,predmeand, color="orange")
plot!(xs,lwr,color="blue", linestyle = :dash)
plot!(xs,upr,color="blue", linestyle = :dash)
plot!(xs,lwrd,color="orange", linestyle = :dash)
plot!(xs,uprd,color="orange", linestyle = :dash)
png("./Desktop/pi1.png")

####
y2 = rand.(rng,Bernoulli.(logistic.(X*beta)))
f2(beta) = logisticloss(beta,y2,X)

@time out2 = VItools.GDChol(f2, betaini, 1.0, 1.0e-3, 2000)
@time outd2 = VItools.GDdiag(f2, betaini, 1.0, 1.0e-3, 2000)

plot(out2[3],label="Cholesky")
plot!(outd2[3],label="diagonal")
png("./Desktop/loss2.png")

betasmp2 = out2[1].+out2[2]*eps
betasmpd2 = outd2[1].+outd2[2].*eps
post = logistic.(Xs*betasmp2)
postd = logistic.(Xs*betasmpd2)
postmean = mean(post,dims=2)
postmeand = mean(postd,dims=2)
lwr2 = [quantile(post[i,:],0.025) for i in 1:100]
upr2 = [quantile(post[i,:],0.975) for i in 1:100]
lwrd2 = [quantile(postd[i,:],0.025) for i in 1:100]
uprd2 = [quantile(postd[i,:],0.975) for i in 1:100]

scatter(xs,y2,msw=0,color="gray",legend=false)
plot!(xs,postmean, color="blue")
plot!(xs,postmeand, color="orange")
plot!(xs,lwr2,color="blue", linestyle = :dash)
plot!(xs,upr2,color="blue", linestyle = :dash)
plot!(xs,lwrd2,color="orange", linestyle = :dash)
plot!(xs,uprd2,color="orange", linestyle = :dash)
png("./Desktop/pi2.png")

opt = optimize(f,betaini,method=BFGS(),autodiff=:forward)
betahat = Optim.minimizer(opt)
out[1]
betahat

inv(Symmetric(ForwardDiff.hessian(f,betahat)))
out[2]*out[2]'

opt2 = optimize(f2,betaini,method=BFGS(),autodiff=:forward)
betahat2 = Optim.minimizer(opt2)
out2[1]
betahat2

inv(Symmetric(ForwardDiff.hessian(f2,betahat2)))
out2[2]*out2[2]'

function dp(a,b,y,x)
    exp(sum(logpdf(Poisson(exp(a+x[i]*b)),y[i]) for i in eachindex(y)) + logpdf(Normal(0.0,1.0),a) + logpdf(Normal(0.0,1.0),b))
end 

function dp2(a,b,y,x)
    exp(sum(logpdf(Bernoulli(logistic(a+x[i]*b)),y[i]) for i in eachindex(y)) + logpdf(Normal(0.0,1.0),a) + logpdf(Normal(0.0,1.0),b))
end 

av = 1.8:0.005:2.1
bv = -1.2:0.005:-0.9
Xv = repeat(reshape(av, 1, :), length(bv), 1)
Yv = repeat(bv, 1, length(av))
Zv = map((a,b) -> dp(a,b,y,x), Xv, Yv)
contour(av, bv, Zv, c=:grays, linewidth=2, colorbar=false)
scatter!(betasmp[1,:], betasmp[2,:], ma=0.1, ms=3, msw=0, color=:royalblue, label="Cholesky")
scatter!(betasmp[1,:], betasmpd[2,:], ma=0.1, ms=3, msw=0, color=:orange, label="diagonal")
png("./Desktop/contour1.png")

av = 0.5:0.01:2.5
bv = -1.5:0.01:0.5
Xv = repeat(reshape(av, 1, :), length(bv), 1)
Yv = repeat(bv, 1, length(av))
Zv = map((a,b) -> dp2(a,b,y2,x), Xv, Yv)
contour(av, bv, Zv, c=:grays, linewidth=2, colorbar=false)
scatter!(betasmp2[1,:], betasmp2[2,:], ma=0.1, ms=3, msw=0, color=:royalblue, label="Cholesky")
scatter!(betasmpd2[1,:], betasmpd2[2,:], ma=0.1, ms=3, msw=0, color=:orange, label="diagonal")
png("./Desktop/contour2.png")

GDdiag は変分事後分布を独立(分散共分散行列が対角行列)とした方法です。

Julia は 1.6 を使ってます。

julia> VERSION
v"1.6.3"

いくつかの結果

ポアソン回帰で負の対数尤度が小さくなっていく様子です。

ロス(ポアソン分布)

95%予測区間です。

予測区間ポアソン分布)

青が提案手法(コレスキー分解された分散共分散行列を持つ正規分布で事後分布を近似)オレンジが従来法(独立な正規分布で事後分布を近似)ですがほぼ一致。

ロジスティック回帰で負の対数尤度が小さくなっていく様子です。

ロス(ベルヌーイ分布)

95%信用区間です。線がぐにゃっとしていたほうがおもしろいという理由で予測区間(データのばらつき)でなく信用区間(パラメータのばらつき)をプロットしてます。

信用区間(ベルヌーイ分布)

真の事後分布は尤度と事前分布の積の定数倍なので、それを等高線で描いて変分事後分布からサンプリングした乱数を散布図で重ねました。

事後分布との比較(ポアソン分布)
事後分布との比較(ベルヌーイ分布)

ロジスティック回帰のほうは大差ないけど、ポアソン回帰のほうは相関を入れたほうが真の事後分布に近い形になりました。

最尤法と比較するとこんな感じ:

#ポアソン回帰
julia> betahat = Optim.minimizer(opt)
2-element Vector{Float64}:
  1.9615359831503478
 -1.0272951225324904

julia> out[1]
2-element Vector{Float64}:
  1.9609649365590203
 -1.0309899950651866

julia> betahat
2-element Vector{Float64}:
  1.9615359831503478
 -1.0272951225324904

julia> inv(Symmetric(ForwardDiff.hessian(f,betahat)))
2×2 Symmetric{Float64, Matrix{Float64}}:
 0.00187025   0.000927429
 0.000927429  0.000713584

julia> out[2]*out[2]'
2×2 Matrix{Float64}:
 0.00180625  0.00101994
 0.00101994  0.000909128

#ロジスティック回帰
julia> betahat2 = Optim.minimizer(opt2)
2-element Vector{Float64}:
  1.4335809051433548
 -0.7636614032163528

julia> out2[1]
2-element Vector{Float64}:
  1.3288174019948678
 -0.7083837999431308

julia> betahat2
2-element Vector{Float64}:
  1.4335809051433548
 -0.7636614032163528

julia> inv(Symmetric(ForwardDiff.hessian(f2,betahat2)))
2×2 Symmetric{Float64, Matrix{Float64}}:
  0.0758743  -0.0259558
 -0.0259558   0.0729946

julia> out2[2]*out2[2]'
2×2 Matrix{Float64}:
  0.050385   -0.0102875
 -0.0102875   0.0695762

応用の可能性

変分ベイズでは必ずしもすべてのパラメータの変分事後分布を独立と設定する必要はなく、いくつかのブロックごとに独立としてもいい。たとえば以下のような場合は提案手法がわりとかんたんに実装できると思う:

  • シェイプパラメータとスケールパラメータがあるような分布で、シェイプとスケールを独立にする
  • ニューラルネットの層ごとに独立とする