アルゴリズム
須山『ベイズ推論による機械学習入門』ではロジスティック回帰とニューラルネットのところで近似事後分布として1次元の正規分布を仮定して、変分パラメータを勾配法で推定するアルゴリズムが出てくる。
勾配を評価するときには再パラメータ化トリック(re-parameterization trick)というアイデアを使う。
これはニューラルネットワークに限らずどんな分布でも使えるアイデアで便利だと思ったので、まとめ直してみる。
ここではパラメータの添字は省略する。
まずすべてのパラメータ に対して、平均0、分散 の正規事前分布 を設定する。
すべてのパラメータの近似事後分布として、平均、分散 の正規分布 を設定する。
標準偏差は正の数なので とおいて、 を最適化する。
(『ベイズ推論による機械学習入門』では、となっているが、微分をかんたんにするため、指数関数にしてみた。)
真の事後分布とのカルバック・ライブラ距離が近い正規分布を求めるのがアルゴリズムの目的。
標準正規乱数 を使ってサンプル を得るとこれは、近似事後分布からのサンプルそのものになる。観測された目的変数を 、 説明変数を として、尤度を と書くと真の事後分布と近似事後布のカルバックライブラ距離は、
と近似できる。
1個のサンプルの標本平均で期待値を近似しているわけで、モンテカルロEMとかを使ったことがある人は変に感じるかもしれないけれでも、これでもけっこううまくいく。もちろん複数サンプリングして平均をとってもよいが計算が大変になる。
あとはこれを微分して、勾配法で と を更新してやる。
なので、
また、合成関数の微分なので、
を忘れないようにかけてやる。
まとめると次のアルゴリズムが得られる。
- 学習率 と 事前分布の精度パラメータ を設定。
- と を適当に初期化し以下を繰り返す。
- すべてのパラメータに対して、標準正規乱数 を使ってサンプル を得る。
- を計算。
- とする。
- とする。
- で更新
- で更新
Julia のコード
Julia には自動微分のパッケージ ForwardDiff があるので対数尤度の微分のところは自分で計算しなくてよい(場合もある)。
今回は乱数で適当に作ったデータでポアソン回帰をやってみる。
こんなふうだ:
using Distributions using ForwardDiff using Random using Plots using Optim using LinearAlgebra function poisonloss(beta,y,X) Xbeta = X*beta lambda = exp.(Xbeta) return -sum(y .* Xbeta - lambda) end function VIGD(f, par0, lambda, lr, maxiter::Int) rng = Random.default_rng() len = length(par0) mu = randn(rng,len) rho = randn(rng,len) g(beta) = ForwardDiff.gradient(beta0 -> f(beta0),beta) logloss = zeros(maxiter) for i in 1:maxiter sigma = exp.(rho) epsilon = randn(rng,len) beta = mu + sigma.*epsilon fx = f(beta) gvec = g(beta) g_mu = gvec + lambda*beta g_rho = (gvec.*epsilon + lambda*beta.*epsilon).*sigma .- 1.0 mu = mu - lr * g_mu rho = rho - lr * g_rho logloss[i] = fx end return mu, rho, logloss end rng = Random.default_rng() Random.seed!(1) x = sort(randn(rng,100)) X = [ones(100) x] beta = [2.0,-1.0] y = rand.(rng,Poisson.(exp.(X*beta))) f(beta) = poisonloss(beta,y,X) betaini = [0.0,0.0] @time μ, ρ, logloss = VIGD(f, betaini, 0.0, 1.0e-4, 5000) #計算時間は0.2秒くらい plot(logloss,legend=false) ε = randn(rng,2,1000) betasmp = μ .+ exp.(ρ).*ε post = exp.(X*betasmp) pred = rand.(rng,Poisson.(post)) predmean = mean(pred,dims=2) lwr = [quantile(pred[i,:],0.025) for i in 1:100] upr = [quantile(pred[i,:],0.975) for i in 1:100] scatter(x,y,legend=false) plot!(x,predmean, color="blue") plot!(x,lwr,color="blue", linestyle = :dash) plot!(x,upr,color="blue", linestyle = :dash) png("./Desktop/plot.png") opt = optimize(f,betaini,method=BFGS(),autodiff=:forward) β = Optim.minimizer(opt) se = sqrt.(diag(inv(Symmetric(ForwardDiff.hessian(f,β)))))
poisonloss のところを書き換えればロジスティック回帰でもワイブル回帰でもできるはず。
予測分布の95%区間を点線でプロットした。
計算が進むにつれて負の対数尤度が小さくなっていく様子です。
ちなみに最尤推定の標準誤差にくらべると変分近似した事後分布の標準偏差はやや小さめに求まる。
julia> println(μ) [1.999546299946068, -1.0194307666319837] julia> println(β) [2.0010916047667635, -1.0185383086714856] julia> println(exp.(ρ)) [0.035814747235478735, 0.026642018884133416] julia> println(se) [0.04194236834667879, 0.031554017967900964]