ジョンとヨーコのイマジン日記

キョウとアンナのラヴラブダイアイリー改め、ジョンとヨーコのイマジン日記です。

事後分布を正規分布で近似する変分推論のアルゴリズムを Julia で書く

アルゴリズム


須山『ベイズ推論による機械学習入門』ではロジスティック回帰とニューラルネットのところで近似事後分布として1次元の正規分布を仮定して、変分パラメータを勾配法で推定するアルゴリズムが出てくる。

勾配を評価するときには再パラメータ化トリック(re-parameterization trick)というアイデアを使う。

これはニューラルネットワークに限らずどんな分布でも使えるアイデアで便利だと思ったので、まとめ直してみる。

ここではパラメータの添字は省略する。

まずすべてのパラメータ w に対して、平均0、分散 \lambda^{-1} の正規事前分布 p(w) を設定する。

\displaystyle \log p(w) =\frac{1}{2} \log \lambda - \lambda \frac{w^2}{2}+\mathrm{Const.}

すべてのパラメータの近似事後分布として、平均\mu、分散 \sigma^{2}正規分布 q(w) を設定する。

\displaystyle \log q(w) = - \log \sigma -  \frac{(w-\mu)^2}{2\sigma^2}+\mathrm{Const.}

標準偏差は正の数なので \sigma = \exp(\rho) とおいて、\rho を最適化する。

(『ベイズ推論による機械学習入門』では、\sigma = \log(1+\exp(\rho))となっているが、微分をかんたんにするため、指数関数にしてみた。)

真の事後分布とのカルバック・ライブラ距離が近い正規分布を求めるのがアルゴリズムの目的。

標準正規乱数 \varepsilon を使ってサンプル \tilde{w} = \mu + \sigma \tilde{\varepsilon} を得るとこれは、近似事後分布からのサンプルそのものになる。観測された目的変数を  Y、 説明変数を X として、尤度を  p(Y|X,w) と書くと真の事後分布と近似事後布のカルバックライブラ距離は、

 \displaystyle \operatorname{KL}(q(w)\|p(w|Y,X))\\
 \displaystyle \approx \log q(\tilde w)- \log p(\tilde w) - \log p(Y|X,\tilde w)+\mathrm{Const.}

と近似できる。

1個のサンプルの標本平均で期待値を近似しているわけで、モンテカルロEMとかを使ったことがある人は変に感じるかもしれないけれでも、これでもけっこううまくいく。もちろん複数サンプリングして平均をとってもよいが計算が大変になる。

あとはこれを微分して、勾配法で \mu\sigma を更新してやる。

\displaystyle \log q(\tilde w) = - \log \sigma -  \frac{(\tilde w - \mu)^2}{2\sigma^2}+\mathrm{Const.}\\
\displaystyle= - \log \sigma -  \frac{(\mu+\sigma \varepsilon-\mu)^2}{2\sigma^2}+\mathrm{Const.}\\
\displaystyle = - \log \sigma -  \frac{ \varepsilon^2}{2}+\mathrm{Const.}

なので、
 \displaystyle \frac{d}{d\mu}\log q(\tilde w) =0
 \displaystyle \frac{d}{d\sigma}\log q(\tilde w) = (-1/\sigma)
 \displaystyle \frac{d}{d\mu}\log p(\tilde w) =-\lambda \tilde w
 \displaystyle \frac{d}{d\sigma}\log p(\tilde w) = -\lambda \tilde w \varepsilon

また、合成関数の微分なので、
 \displaystyle \frac{d}{d\rho} \exp(\rho) =\exp(\rho)
を忘れないようにかけてやる。

まとめると次のアルゴリズムが得られる。

  1. 学習率 \alpha と 事前分布の精度パラメータ  \lambda を設定。
  2. \mu\rho を適当に初期化し以下を繰り返す。
  3. すべてのパラメータに対して、標準正規乱数 \tilde{\varepsilon} を使ってサンプル  \tilde{w} = \mu + \exp(\rho) \tilde{\varepsilon} を得る。
  4.  g = -\frac{d}{dw} \log p(Y|X, \tilde w) を計算。
  5.  g_\mu =  g + \lambda\tilde{w} とする。
  6.  g_\rho = (g \tilde{\varepsilon} + \lambda \tilde{w} \tilde{\varepsilon})\exp(\rho) -1 とする。
  7.  \mu \leftarrow \mu + \alpha g_\mu で更新
  8.  \rho \leftarrow \rho + \alpha g_\rho で更新

Julia のコード

Julia には自動微分のパッケージ ForwardDiff があるので対数尤度の微分のところは自分で計算しなくてよい(場合もある)。

今回は乱数で適当に作ったデータでポアソン回帰をやってみる。

こんなふうだ:

using Distributions
using ForwardDiff
using Random
using Plots
using Optim
using LinearAlgebra

function poisonloss(beta,y,X)
    Xbeta = X*beta
    lambda = exp.(Xbeta)
    return -sum(y .* Xbeta - lambda)
end

function VIGD(f, par0, lambda, lr, maxiter::Int) 
    rng = Random.default_rng()
    len = length(par0)
    mu = randn(rng,len)
    rho = randn(rng,len)
    g(beta) = ForwardDiff.gradient(beta0 -> f(beta0),beta)
    logloss = zeros(maxiter)
    for i in 1:maxiter
    sigma = exp.(rho)
    epsilon = randn(rng,len)
    beta = mu + sigma.*epsilon
    fx = f(beta)
    gvec = g(beta)
    g_mu = gvec + lambda*beta
    g_rho = (gvec.*epsilon + lambda*beta.*epsilon).*sigma .- 1.0
    mu = mu - lr * g_mu
    rho = rho - lr * g_rho
    logloss[i] = fx
    end
    return mu, rho, logloss
end

rng = Random.default_rng()
Random.seed!(1)
x = sort(randn(rng,100))
X = [ones(100) x]
beta = [2.0,-1.0]

y = rand.(rng,Poisson.(exp.(X*beta)))

f(beta) = poisonloss(beta,y,X)
    
betaini = [0.0,0.0]
@time μ, ρ, logloss = VIGD(f, betaini, 0.0, 1.0e-4, 5000)
#計算時間は0.2秒くらい

plot(logloss,legend=false)

ε = randn(rng,2,1000)
betasmp = μ .+ exp.(ρ).*ε 

post = exp.(X*betasmp)
pred = rand.(rng,Poisson.(post))
predmean = mean(pred,dims=2)
lwr = [quantile(pred[i,:],0.025) for i in 1:100]
upr = [quantile(pred[i,:],0.975) for i in 1:100]

scatter(x,y,legend=false)
plot!(x,predmean, color="blue")
plot!(x,lwr,color="blue", linestyle = :dash)
plot!(x,upr,color="blue", linestyle = :dash)
png("./Desktop/plot.png")

opt = optimize(f,betaini,method=BFGS(),autodiff=:forward)
β = Optim.minimizer(opt)
se = sqrt.(diag(inv(Symmetric(ForwardDiff.hessian(f,β)))))

poisonloss のところを書き換えればロジスティック回帰でもワイブル回帰でもできるはず。

予測分布の95%区間を点線でプロットした。

f:id:abrahamcow:20201229143541p:plain

計算が進むにつれて負の対数尤度が小さくなっていく様子です。

f:id:abrahamcow:20201229143941p:plain

ちなみに最尤推定の標準誤差にくらべると変分近似した事後分布の標準偏差はやや小さめに求まる。

julia> println(μ)
[1.999546299946068, -1.0194307666319837]

julia> println(β)
[2.0010916047667635, -1.0185383086714856]

julia> println(exp.(ρ))
[0.035814747235478735, 0.026642018884133416]

julia> println(se)
[0.04194236834667879, 0.031554017967900964]

とはいえ、平均はともかく標準偏差のほうはイテレーションの回数によってけっこう変わる。

Albert (2008) 打者の調子の波のモデル化

モチベーション

以前にAlbert (2008)、
https://www.stat.berkeley.edu/~aldous/157/Papers/albert_streaky.pdf
を読んでやってみたことのJulia版です。

次の0と1の羅列はカルロス・ギーエンカルロス・ギーエン - Wikipedia)という選手の2005年の打撃成績のデータで、ヒットを 1、アウトを 0 とコード化してある。

0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0,
1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0,
0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1,
0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0,
0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0,
0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1

年間通しての打率は約 3 割 2 分。

上記のデータの20打席ごとの移動平均をとってみる。

ぼくは野球に詳しくないのだけれど、20 はおよそ 4 試合ごとの打席数とのこと。

f:id:abrahamcow:20201223001549p:plain

こうして見てみるとある時期には 6 割を超えていたり、またある時期にはほとんど 0 に近かったりする。

これはギーエンという打者の「調子の波」だと言っていいだろうか。

パラメトリックブートストラップによる検定

調子の波が存在しない選手は、常にコンスタントな打率でヒットを出すから、その打撃成績を上記のように 0 か 1 かに符号化すると、それはベルヌーイ過程になる。

このコンスタントバッターを乱数でシミュレーションしてみる。

f:id:abrahamcow:20201223001759p:plain

シミュレーションの方にもそれなりに意味ありげな波ができている気がするが、波の幅はギーエンのデータよりも狭い感じがする。

こうなってくるとちゃんと検定したくなるのが人情だろう。

帰無仮説:「ギーエンの0-1のプロセスがベルヌーイ過程である。」

対立仮説:「ギーエンの0-1のプロセスがベルヌーイ過程でない。」

モチベーションとなった例で出した移動平均から検定統計量を作ってみる。

  • 移動平均のレンジ(the range of the moving averages)  R = \max_j m_j − \min_j m_j
  • 移動平均のシーズン平均からの平均変動(the mean variation of the moving averages about the season average)

\displaystyle B =\frac{1}{n-w-1}\sum^{n-w+1}_{j=1}|m_j-\bar y|

m_j移動平均w移動平均のウィンドウサイズ、n はトータルの打席数である。

この B を Albert (2008) では "black" 統計量と呼んでいるが、なぜ "black" なのかはよくわからなかった。

帰無分布はシミュレーションで作れるので、シミュレートされた B_i と観測された B を比べてやれば p 値が出る。

 p = \sum_{i=1}^M I(B_i > B)/M

M はシミュレーションの試行回数、 I(\cdot) は中の不等号を満たすときに 1 そうでなければ 0 の値を返す関数である。

検定統計量の分布をヒストグラムで示す。

まず R

f:id:abrahamcow:20201223003418p:plain

次に B
f:id:abrahamcow:20201223003420p:plain

p 値はそれぞれ 0.0262、0.0286 で 5%水準で有意になった。

ベータ二項モデル

パラメトリックブートストラップによる検定からはギーエンには調子の波があるといえそうなことはわかった。

ただし B と R は打率や打席数に依存するため、選手間で調子の波の強さを比較したりはできない。

そこでベータ二項モデルを導入してギーエンのヒットの数をモデル化することを考える。

ギーエンのデータを 20 打席ごとに区切って次のように集計する。

ヒットの数 打席数
5 20
5 20
7 20
10 20
10 20
10 20
6 20
9 20
4 20
4 20
6 20
7 20
4 20
2 20
6 20
5 20
5 14

ヒットの数は二項分布に従うと仮定する。

20打席ごとの二項分布の成功確率 p の事前分布にベータ分布を置く。

ここでの工夫はベータ分布を以下のようにパラメタライズすること。

\displaystyle \frac{1}{B(K\eta,K(1-\eta))}p^{K\eta-1}(1-p)^{K(1-\eta)-1}

こうすると \eta は打率に対応するパラメータ、K は打率の集中度に関するパラメータと解釈できる。

K が大きいほどばらつきが小さくなる。

Albert (2008) に習い K\eta に対して次の事前分布を仮定する。

\displaystyle g(\eta, K) \propto \frac{1}{\eta (1-\eta)}\frac{1}{(1+K)^2}

ちなみにこの事前分布を外すと、MCMC がまったく収束しなかった(と昔の日記に書いてある)。

推定には Stan を使った。

Stan のコードはこう。

data {
  int<lower=0> N;
  int<lower=0> m;
  int<lower=0> n[m];
  int<lower=0> x[m];
}
parameters {
  real<lower=0, upper=1> p[m];
  real<lower=0> K;
  real<lower=0, upper=1> eta;
}
model {
  x ~ binomial(n, p);
  p ~ beta(K*eta, K*(1-eta));
  target += -(log(eta)+log1m(eta))-2*log(K);
}
generated quantities{
  int<lower=0,upper=1> pred[N];
  for(i in 1:N){
    pred[i] = bernoulli_rng(p[i/20+1]);
  }
}

K\eta の事後分布のトレースプロットです。

まずK
f:id:abrahamcow:20201223005731p:plain

次に\eta
f:id:abrahamcow:20201223005757p:plain

MCMC を回すのと同時に generated quantities ブロックで0-1の予測分布を生成していた。

予測分布でもパラメトリックブートストラップによる検定と同様のことができる。

統計量の分布をヒストグラムで示す。

まず R
f:id:abrahamcow:20201223004836p:plain

次に B
f:id:abrahamcow:20201223004834p:plain

p 値はそれぞれ 0.1789、0.2906 で 5%水準では棄却されない。

つまりモデルはそこそこ当てはまっていることが伺える。

最後に予測分布から計算した移動平均の平均をプロットしてみる。

f:id:abrahamcow:20201223005316p:plain

Julia のコード

using Plots
using Distributions
using Random
using CmdStan
set_cmdstan_home!(homedir()*"/projects/cmdstan")

modBetaBinom = Stanmodel(name="BetaBinom", model=open(f->read(f, String), "./Documents/BetaBinom.stan"), nchains=4, num_warmup=2500,num_samples=2500)

GuillenC = [0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0,
1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0,
0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1,
0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0,
0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0,
0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1]

moving_average(vs,w) = [sum(vs[i:(i+w-1)])/w for i in 1:(length(vs)-(w-1))]

N = length(GuillenC)
avg = mean(GuillenC)
ma = moving_average(GuillenC,20)
plot(ma,legend=false)
hline!([avg])
png("./Desktop/ma.png")

rng = Random.default_rng()
Random.seed!(1)
y = rand(Bernoulli(avg),N)
ma_c = moving_average(y,20)
plot(ma, label="obs",legend=:top)
plot!(ma_c, label="sim")
png("./Desktop/ma_sim.png")

Bobs = mean(abs,ma.-avg)
Robs = maximum(ma)-minimum(ma)
function bootBlack(M,avg)
    B_boot = zeros(M)
    R_boot = zeros(M)
    for i in 1:M
    y = rand(Bernoulli(avg),N)
    ma_c = moving_average(y,20)
    B_boot[i] = mean(abs,ma_c.-avg)
    R_boot[i] = maximum(ma_c)-minimum(ma_c)
    end
    return B_boot, R_boot
end
Random.seed!(1)
B_boot, R_boot = bootBlack(10000,avg)
p_B = mean(B_boot.>Bobs)
p_R = mean(R_boot.>Robs)

histogram(R_boot,legend=false)
vline!([Robs])
png("./Desktop/boot_R.png")

histogram(B_boot,legend=false)
vline!([Bobs])
png("./Desktop/boot_B.png")

m = div(N,20)
x = zeros(Integer,m+1)
n = zeros(Integer,m+1)
j = 0
for i in 1:m
  x[i] = sum(GuillenC[(j+1):(j+20)])
  n[i] = 20
  j += 20
end
x[m+1] = sum(GuillenC[(j+1):N])
n[m+1]= N-j
println(x)

dat = Dict("N" => N, "m" => m+1, "n" => n, "x" => x)
rc, samples, cnames = stan(modBetaBinom, [dat])

K=[samples[:,cnames.=="K",i] for i in 1:4]
plot(K,legend=false)

eta=[samples[:,cnames.=="eta",i] for i in 1:4]
plot(eta,legend=false)

predpos = [occursin("pred.",cnames[i]) for i in eachindex(cnames)]
y_pred =[samples[:,predpos,1];samples[:,predpos,2];samples[:,predpos,3];samples[:,predpos,4]]
function predBlack(ypred,avg)
    M = size(y_pred,1)
    B_boot = zeros(M)
    R_boot = zeros(M)
    for i in 1:M
    ma_c = moving_average(y_pred[i,:],20)
    B_boot[i] = mean(abs,ma_c.-avg)
    R_boot[i] = maximum(ma_c)-minimum(ma_c)
    end
    return B_boot, R_boot
end

B_pred, R_pred = predBlack(y_pred,avg)

histogram(R_pred,legend=false)
vline!([Robs])
png("./Desktop/pred_R.png")

histogram(B_pred,legend=false)
vline!([Bobs])
png("./Desktop/pred_B.png")

mean(B_pred.>Bobs)
mean(R_pred.>Robs)

ma_pred = [moving_average(y_pred[i,:],20) for i in 1:10000]
ma_predmean = mean(ma_pred,dims=1)
plot(ma,legend=false)
plot!(ma_predmean,legend=false)
png("./Desktop/pred_ma_mean.png")

CmdStanのインストールは
1 CmdStan Installation | CmdStan User’s Guide
を参照。

set_cmdstan_home!(homedir()*"/projects/cmdstan")

の行は自分の環境のCmdStanのディレクトリのパスを入れること。

ぼくは毎回これを書いてるけど、もっとうまいやり方がありそうな気はする。