アルゴリズムの導出
変分事後分布に正規分布を仮定して、変分パラメータを勾配法で推定する方法はそれなりによく知られている。例えば須山「ベイズ推論による機械学習入門」を参照されたい。
そこでは変分事後分布はすべて互いに独立であるようにとるが、事前分布が独立でも事後分布は独立にならないので、この独立性の仮定を少し緩めてやったほうが事後分布の近似は良くなると思われる。
ところで、多変量正規分布(相関のある正規分布)に従う乱数のベクトル がほしいときは、分散共分散行列 について となるような下三角行列 を用いて、 とする方法が広く使われている。 なる分解をコレスキー分解と呼ぶ。ここで は平均ベクトル、 は標準正規乱数のベクトルである。
さて、モデルのすべての未知パラメータを とまとめて置き、その変分事後分布 を平均 、分散共分散行列 の多変量正規分布(Stan でいう multi_normal_cholesky; 22.3 Multivariate Normal Distribution, Cholesky Parameterization | Stan Functions Reference)とする方法を考えてみる。
ここで は変分パラメータ(知りたいターゲット)の と のとり方に依存しない定数項で、微分すれば消える。
多くの場合、コレスキー分解では の対角成分が正になるようにするが、計算したあとに符号を変えてもいいので、今回は気にしないことにする。
標準正規乱数のベクトル を使ってサンプル を得ると、これは近似事後分布からのサンプルそのものなので、観測された目的変数を 、 説明変数を として、尤度を と書くと真の事後分布と変分事後分布のカルバックライブラ距離は、
と書ける。このようなアイデアを再パラメータ化トリック(re-parameterization trick; [1505.05424] Weight Uncertainty in Neural Networks)と呼ぶ。
この を微分して、勾配法で と を更新してやる。真の事後分布とのカルバック・ライブラ距離が近い正規分布を求めるのが目標である。
は を代入すると
である。
また、正規分布でなければいけないわけではないが、計算が楽になるようにモデルのすべての未知パラメータ に対して、平均0、分散 の単変量正規事前分布 を設定する(この はパラメータのベクトル の各要素だが、見やすさのため添字は省略した)と、
なので、
は の対角成分の逆数からなる対角行列である。理由は行列の要素を愚直に書き下してみればわかる。
まとめると次のアルゴリズムが得られる:
- 学習率 と 事前分布の精度パラメータ を設定。
- と を適当に初期化し以下を繰り返す。
- すべてのパラメータに対して、標準正規乱数 を使ってサンプル を得る。
- を計算。
- とする。
- とする。
- で更新。
- で更新。
Julia のコード
Julia には ForwardDiff など、自動微分のパッケージがあるので対数尤度の微分のところは自分で計算しなくてよい(場合もある)。
今回は乱数で適当に作ったデータでロジスティック回帰とポアソン回帰をやってみる。
サンプルサイズは100とした。
こんなふうだ:
module VItools using Distributions using ForwardDiff using Random using LinearAlgebra function inittri(rng,n) L = zeros(n,n) for i in 1:n L[i:n,i] = randn(rng,n-i+1) end return LowerTriangular(L) end function prodLower(a,b) n = length(a) L = zeros(n,n) for i in 1:n for j in i:n L[j,i] = a[j]*b[i] end end return LowerTriangular(L) end function GDChol(f, par0, lambda, lr, maxiter) rng = Random.default_rng() len = length(par0) mu = randn(rng,len) L = inittri(rng,len) g(beta) = ForwardDiff.gradient(beta0 -> f(beta0), beta) logloss = zeros(maxiter) for i in 1:maxiter epsilon = randn(rng,len) beta = mu + L*epsilon fx = f(beta) gvec = g(beta) g_mu = gvec + lambda*beta g_L = prodLower((gvec + lambda*beta), epsilon) - Diagonal(1.0 ./diag(L)) mu -= lr * g_mu L -= lr * g_L logloss[i] = fx end return mu, L, logloss end function GDdiag(f, par0, lambda, lr, maxiter) rng = Random.default_rng() len = length(par0) mu = randn(rng,len) L = randn(rng,len) g(beta) = ForwardDiff.gradient(beta0 -> f(beta0),beta) logloss = zeros(maxiter) for i in 1:maxiter epsilon = randn(rng,len) beta = mu + L.*epsilon fx = f(beta) gvec = g(beta) g_mu = gvec + lambda*beta g_L = (gvec + lambda*beta).*epsilon .- 1.0 ./ L mu -= lr * g_mu L -= lr * g_L logloss[i] = fx end return mu, L, logloss end end using Distributions using ForwardDiff using Random using Plots using Optim using StatsFuns using LinearAlgebra function poisonloss(beta,y,X) Xbeta = X*beta return -sum(y[i] * Xbeta[i] - exp(Xbeta[i]) for i in eachindex(y)) end function logisticloss(beta,y,X) Xbeta = X*beta return sum(y[i] * log1p(exp(-Xbeta[i])) + (1 - y[i]) * log1p(exp(Xbeta[i])) for i in eachindex(y)) end rng = Random.default_rng() Random.seed!(123456) x = randn(rng,100) X = [ones(100) x] beta = [2.0, -1.0] y = rand.(rng, Poisson.(exp.(X*beta))) f(beta) = poisonloss(beta,y,X) betaini = rand(rng, 2) @time out = VItools.GDChol(f, betaini, 1.0, 1.0e-5, 2000) @time outd = VItools.GDdiag(f, betaini, 1.0, 1.0e-5, 2000) plot(out[3],label="Cholesky") plot!(outd[3],label="diagonal") png("./Desktop/loss1.png") eps = randn(rng,2,10000) betasmp = out[1].+out[2]*eps betasmpd = outd[1].+outd[2].*eps xs = sort(x) Xs = [ones(100) xs] pred = rand.(rng,Poisson.(exp.(Xs*betasmp))) predd = rand.(rng,Poisson.(exp.(Xs*betasmpd))) predmean = mean(pred,dims=2) predmeand = mean(predd,dims=2) lwr = [quantile(pred[i,:],0.025) for i in 1:100] upr = [quantile(pred[i,:],0.975) for i in 1:100] lwrd = [quantile(predd[i,:],0.025) for i in 1:100] uprd = [quantile(predd[i,:],0.975) for i in 1:100] scatter(x,y,msw=0,color="gray",legend=false) plot!(xs,predmean, color="blue") plot!(xs,predmeand, color="orange") plot!(xs,lwr,color="blue", linestyle = :dash) plot!(xs,upr,color="blue", linestyle = :dash) plot!(xs,lwrd,color="orange", linestyle = :dash) plot!(xs,uprd,color="orange", linestyle = :dash) png("./Desktop/pi1.png") #### y2 = rand.(rng,Bernoulli.(logistic.(X*beta))) f2(beta) = logisticloss(beta,y2,X) @time out2 = VItools.GDChol(f2, betaini, 1.0, 1.0e-3, 2000) @time outd2 = VItools.GDdiag(f2, betaini, 1.0, 1.0e-3, 2000) plot(out2[3],label="Cholesky") plot!(outd2[3],label="diagonal") png("./Desktop/loss2.png") betasmp2 = out2[1].+out2[2]*eps betasmpd2 = outd2[1].+outd2[2].*eps post = logistic.(Xs*betasmp2) postd = logistic.(Xs*betasmpd2) postmean = mean(post,dims=2) postmeand = mean(postd,dims=2) lwr2 = [quantile(post[i,:],0.025) for i in 1:100] upr2 = [quantile(post[i,:],0.975) for i in 1:100] lwrd2 = [quantile(postd[i,:],0.025) for i in 1:100] uprd2 = [quantile(postd[i,:],0.975) for i in 1:100] scatter(xs,y2,msw=0,color="gray",legend=false) plot!(xs,postmean, color="blue") plot!(xs,postmeand, color="orange") plot!(xs,lwr2,color="blue", linestyle = :dash) plot!(xs,upr2,color="blue", linestyle = :dash) plot!(xs,lwrd2,color="orange", linestyle = :dash) plot!(xs,uprd2,color="orange", linestyle = :dash) png("./Desktop/pi2.png") opt = optimize(f,betaini,method=BFGS(),autodiff=:forward) betahat = Optim.minimizer(opt) out[1] betahat inv(Symmetric(ForwardDiff.hessian(f,betahat))) out[2]*out[2]' opt2 = optimize(f2,betaini,method=BFGS(),autodiff=:forward) betahat2 = Optim.minimizer(opt2) out2[1] betahat2 inv(Symmetric(ForwardDiff.hessian(f2,betahat2))) out2[2]*out2[2]' function dp(a,b,y,x) exp(sum(logpdf(Poisson(exp(a+x[i]*b)),y[i]) for i in eachindex(y)) + logpdf(Normal(0.0,1.0),a) + logpdf(Normal(0.0,1.0),b)) end function dp2(a,b,y,x) exp(sum(logpdf(Bernoulli(logistic(a+x[i]*b)),y[i]) for i in eachindex(y)) + logpdf(Normal(0.0,1.0),a) + logpdf(Normal(0.0,1.0),b)) end av = 1.8:0.005:2.1 bv = -1.2:0.005:-0.9 Xv = repeat(reshape(av, 1, :), length(bv), 1) Yv = repeat(bv, 1, length(av)) Zv = map((a,b) -> dp(a,b,y,x), Xv, Yv) contour(av, bv, Zv, c=:grays, linewidth=2, colorbar=false) scatter!(betasmp[1,:], betasmp[2,:], ma=0.1, ms=3, msw=0, color=:royalblue, label="Cholesky") scatter!(betasmp[1,:], betasmpd[2,:], ma=0.1, ms=3, msw=0, color=:orange, label="diagonal") png("./Desktop/contour1.png") av = 0.5:0.01:2.5 bv = -1.5:0.01:0.5 Xv = repeat(reshape(av, 1, :), length(bv), 1) Yv = repeat(bv, 1, length(av)) Zv = map((a,b) -> dp2(a,b,y2,x), Xv, Yv) contour(av, bv, Zv, c=:grays, linewidth=2, colorbar=false) scatter!(betasmp2[1,:], betasmp2[2,:], ma=0.1, ms=3, msw=0, color=:royalblue, label="Cholesky") scatter!(betasmpd2[1,:], betasmpd2[2,:], ma=0.1, ms=3, msw=0, color=:orange, label="diagonal") png("./Desktop/contour2.png")
GDdiag は変分事後分布を独立(分散共分散行列が対角行列)とした方法です。
Julia は 1.6 を使ってます。
julia> VERSION v"1.6.3"
いくつかの結果
ポアソン回帰で負の対数尤度が小さくなっていく様子です。
95%予測区間です。
青が提案手法(コレスキー分解された分散共分散行列を持つ正規分布で事後分布を近似)オレンジが従来法(独立な正規分布で事後分布を近似)ですがほぼ一致。
ロジスティック回帰で負の対数尤度が小さくなっていく様子です。
95%信用区間です。線がぐにゃっとしていたほうがおもしろいという理由で予測区間(データのばらつき)でなく信用区間(パラメータのばらつき)をプロットしてます。
真の事後分布は尤度と事前分布の積の定数倍なので、それを等高線で描いて変分事後分布からサンプリングした乱数を散布図で重ねました。
ロジスティック回帰のほうは大差ないけど、ポアソン回帰のほうは相関を入れたほうが真の事後分布に近い形になりました。
最尤法と比較するとこんな感じ:
#ポアソン回帰 julia> betahat = Optim.minimizer(opt) 2-element Vector{Float64}: 1.9615359831503478 -1.0272951225324904 julia> out[1] 2-element Vector{Float64}: 1.9609649365590203 -1.0309899950651866 julia> betahat 2-element Vector{Float64}: 1.9615359831503478 -1.0272951225324904 julia> inv(Symmetric(ForwardDiff.hessian(f,betahat))) 2×2 Symmetric{Float64, Matrix{Float64}}: 0.00187025 0.000927429 0.000927429 0.000713584 julia> out[2]*out[2]' 2×2 Matrix{Float64}: 0.00180625 0.00101994 0.00101994 0.000909128 #ロジスティック回帰 julia> betahat2 = Optim.minimizer(opt2) 2-element Vector{Float64}: 1.4335809051433548 -0.7636614032163528 julia> out2[1] 2-element Vector{Float64}: 1.3288174019948678 -0.7083837999431308 julia> betahat2 2-element Vector{Float64}: 1.4335809051433548 -0.7636614032163528 julia> inv(Symmetric(ForwardDiff.hessian(f2,betahat2))) 2×2 Symmetric{Float64, Matrix{Float64}}: 0.0758743 -0.0259558 -0.0259558 0.0729946 julia> out2[2]*out2[2]' 2×2 Matrix{Float64}: 0.050385 -0.0102875 -0.0102875 0.0695762