ジョンとヨーコのイマジン日記

キョウとアンナのラヴラブダイアイリー改め、ジョンとヨーコのイマジン日記です。

変な形の事後分布のアニメーション

ハイパーパラメータによって事後分布の形が大きく変わる様子です.

変な形の尤度関数をプロットする - ジョンとヨーコのイマジン日記 の続き的な投稿.

例1: 混合正規分布

モデルとして次の分布を考える.

\displaystyle p(y; a, b) = (1-a)\mathcal{N}(y; 0,1) + a\mathcal{N}(y; 0,b)

ここで \mathcal{N}(y; 0,1) は平均0, 分散1 の正規分布の密度関数, 0 \le a \le 1 とする.

さらに事前分布として, パラメータ a にベータ分布 B(a;\theta,\theta),  bに適当な幅の一様分布(密度が定数)を設定する.

この手のモデルは2つの分布が混ざったような形になるので混合分布とよばれ, モデルに基づくクラスタリングなどの際によく使われる.

ただし未知パラメータ(データから推定したいパラメータ)を2個にしてるのは単に図を作りやすくする都合のほうが大きい.

サンプルの大きさを100として, このときの事後分布をプロットしてみる. \theta ごとに, 最大値が1(黄色っぽい色)になるようスケーリングしてある.

\theta が小さいときはサポートの端っこに密度が張り付いていたようになっていたのが大きくなると真ん中によってくる様子.

例2: 非線形回帰

モデルとして次の分布を考える.

\displaystyle p(y; a, b, x) = \mathcal{N}(y, b\tanh(a x),1)

 x は既知の数とする.

事前分布として, パラメータ a,b正規分布 \mathcal{N}(a;0,\theta^2), \mathcal{N}(b;0,\theta^2) を設定する.

例えば  x を時間として, なんらかの施策の前後でデータの変化を見たいときは(データを適当にスケーリングするとして)このようなモデルを考えるかもしれない.

ちなみに一番単純なニューラルネットワークではこのような非線形の変換をベクトルに対して繰り返して行い, 例えば  \mathcal{N}( y; \tanh (\tanh(x w_1) w_2) w_3, 1) のようなモデルを考える.

例1のときと同様, サンプルの大きさを100として, このときの事後分布をプロットしてみる.

R のコード

まとめて貼ります.

library(dplyr)
library(ggplot2)
library(gganimate)

logsumexp2 =function (logx1,logx2){
  logx1 + log1p(exp(logx2-logx1))
}

logsumexp <- function(x){
  mx <- max(x)
  mx + log(sum(exp(x-mx)))
}

llmixnorm <- function(par, y){
  a0 <- par[1]
  b0 <- par[2]
  theta <- par[3]
  sum(logsumexp2(log(1-a0)+dnorm(y,log = TRUE), log(a0)+dnorm(y,b0,log = TRUE))) + dbeta(a0, theta, theta, log = TRUE)
}

llcp <- function(par, y, x){
  a0 <- par[1]
  b0 <- par[2]
  theta <- par[3]
  sum(dnorm(y,a0*tanh(b0*x),log = TRUE)) +
    dnorm(a0, 0, theta, log = TRUE)+
    dnorm(b0, 0, theta, log = TRUE)
}

N <- 100L
set.seed(1)
y11 <- c(rnorm(N/2),rnorm(N/2,0.5))

ggplot()+
  geom_histogram(data=NULL, aes(x=y11), fill="grey70", bins=25)+
  theme_bw(14)+labs(x="y", y="count")

parms <- expand.grid(a=seq(0.01,0.99,length.out = 60),
                     b=seq(0,1,length.out = 60),
                     theta=seq(0.5,2,by=0.1))

l11 <- apply(parms, 1, llmixnorm, y=y11)
dfL11 <- data.frame(parms, value = exp(l11)) %>% 
  group_by(theta) %>% 
  mutate(value=value/max(value)) %>% 
  ungroup()

ggplot(dfL11,aes(x=b,y=a))+
  geom_point(aes(colour=value), pch=15, size=3.5)+
  scale_colour_continuous(type = "viridis")+
  labs(title = 'theta: {round(frame_time,1)}') +
  transition_time(theta)+
  theme_bw(14)+theme(legend.position = "none")

anim_save("post_mixnorm.gif")

####
#tanh

parms <- expand.grid(a=seq(-2,2,length.out = 60),
                     b=seq(-10,10,length.out = 60),
                     theta=seq(1.5,0.1,by=-0.1))

x <- seq(-1,1,length.out=N)
set.seed(1);y11 <- c(rnorm(N/2,-.5),rnorm(N/2,0.5))

ggplot(data = NULL)+
  geom_point(aes(x=x, y=y11))+
  theme_bw(14)

l11 <- apply(parms, 1, llcp, y=y11,x=x)
dfL11 <- data.frame(parms, value = exp(l11)) %>% 
  group_by(theta) %>% 
  mutate(value=value/max(value)) %>% 
  ungroup()


ggplot(dfL11,aes(x=b,y=a))+
  geom_point(aes(colour=value),pch=15, size=3.5)+
  scale_colour_continuous(type = "viridis")+
  labs(title = 'theta: {round(frame_time,2)}') +
  transition_time(theta)+
  theme_bw(14)+theme(legend.position = "none")

anim_save("post_tanh.gif")