AIエンジニア/データサイエンティストが使うベイズ統計モデル

Python3 PyMC3 によるMCMC(Markov chain Monte Carlo)

階層ベイズモデル(hierarchical Bayesian model)

トップページに戻る

Jupyter Notebook の ipynb ファイルをダウンロード

この記事は,次の本を参照しています。

In [1]:
from IPython.display import Image
Image("./image_common/book_Martin.jpg")
Out[1]:

『Pythonによるベイズ統計モデリング』PyMCでのデータ分析実践ガイド  Osvaldo Martin (原著), オズワルド マーティン (著), 金子 武久 (翻訳)

In [2]:
from IPython.display import Image
Image("./image_common/book_Davidson-Pilon.jpg")
Out[2]:

『Pythonで体験するベイズ推論』 ーPyMCによるMCMC入門ー キャメロン デビッドソン=ピロン (著), 玉木 徹 (翻訳)

階層ベイズモデルの事前分布のハイパーパラメータ

Martin著で階層ベイズモデルのコードを調べていたら,何の説明もなく HalfCauchy 分布が出てきて戸惑いましたので,Webで調べました。

詳細はあとで説明しますが,階層ベイズモデルではない普通のベイズモデルはコイントスを例にしますと次のような構造になっています。

ベイズ定理をベイズモデルとして解釈します。

$$ P\left(A|B\right) = \dfrac {P\left(B|A\right)P\left(A\right)}{P\left(B\right)}\propto {P\left(B|A\right)P\left(A\right)} $$

$P\left(B|A\right)$;尤度

$P\left(A\right)$;事前分布

$P\left(A|B\right)$;事後分布

詳しくは説明しませんが,コイントスの尤度の候補は,2項分布です。事前分布の候補は,ベータ分布です。この2つの関係は共役事前分布であり,事後分布もベータ分布となります。

2項分布の1回投げがベルヌーイ分布となります。

ベイズモデルは次のようになります。

$$\theta \sim Beta(\alpha, \beta)$$ $$y \sim Bern(\theta)$$

ベータ分布のパラメータは,$\alpha = \beta = 1$ として,一様分布とすると,無情報事前分布となります。

階層ベイズモデルは,個体差,場所差,グループ差などを扱います。

階層ベイズモデルの事前分布のパラメータは定数ではなく,ハイパーパラメータの確率分布となります。

階層ベイズモデルは次のように書けます。

$$\alpha \sim HalfCauchy(\beta_{\alpha})$$$$\beta \sim HalfCauchy(\beta_{\beta})$$$$\theta \sim Beta(\alpha, \beta)$$$$y \sim Bern(\theta)$$

ハイパーパラメータの確率分布の選択について,コロンビア大学のAndrew Gelman教授の有名な論文があります。R言語のStanの開発者でもあります。

上の論文の日本語解説が次のサイトにあります。

これによりますと,グループ数が5を超える場合は,一様分布 Uniform(0,A)(Aは十分大きく),グループ数が3~5の場合は,半コーシ分布が良いとなっています。

逆ガンマ分布は決して使わないようにとしています。

階層ベイズモデルのKruschkeダイヤグラムは次のようになります。

In [3]:
from IPython.display import Image
Image("./image_common/book_Martin_fig1.png")
Out[3]:
In [4]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import seaborn as sns
import pymc3 as pm
import pandas as pd
plt.style.use('seaborn-darkgrid')
np.set_printoptions(precision=2)
pd.set_option('display.precision', 2)

入力データ

仮想の上水道の水質を推定するモデルを考えます。

水域は3つとします。各水域の水質を推定するとともに,都市全体の水質も推定するモデルを考えます。

N_samplesは各水域のサンプル数です。

G_samplesは各水域の合格数です。

group_idxは尤度ベルヌーイ分布に与えるベータ分布の出力に水域のインデックスを付けます。

dataは尤度ベルヌーイ分布に与えるベータ分布の出力の初期値です。G_samplesの合格数を反映しています。

In [5]:
N_samples = [30, 30, 30]
# G_samples = [18, 18, 18]
G_samples = [18, 15, 12]

group_idx = np.repeat(np.arange(len(N_samples)), N_samples)
data = []
for i in range(0, len(N_samples)):
    data.extend(np.repeat([1, 0], [G_samples[i], N_samples[i]- G_samples[i]]))


print(group_idx)
print(data)
print(len(group_idx), len(data))
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
90 90

モデルの構築

モデルを構築します。データの構造からコイントスと同じになります。事前分布の入力にハイパー事前分布があります。

$$\alpha \sim HalfCauchy(\beta_{\alpha})$$ $$\beta \sim HalfCauchy(\beta_{\beta})$$ $$\theta \sim Beta(\alpha, \beta)$$ $$y \sim Bern(\theta)$$

with pm.Model()のインデントにモデルを書きます。

alphaとbetaはハイパー事前分布HalfCauchyの出力です。

thetaは事前分布Betaの出力です。入力はalphaとbetaです。shapeは水域が3つあることを示しています。

yは尤度ベルヌーイ分布の出力です。入力はthetaですが,[0, 1, 2]のインデックスが付いていて,水域を示しています。あとでそれぞれのthetaを表示できます。

yはサンプルとしてpm.sampleの出力の数列となります。これをtraceとかchainとか呼びます。

pm.sampleの前で,start = で,pm.find_MAPを呼び出すことができますが,ここでは省略されています。

pm.sampleの前で,step = で,MCMCの種類を指定することができますが,ここでは省略されていますので,最適なものが自動で選択されます。ここで選択されたものは,NUTS です。

In [6]:
with pm.Model() as model_h:
    alpha = pm.HalfCauchy('alpha', beta=10)
    beta = pm.HalfCauchy('beta', beta=10)
    theta = pm.Beta('theta', alpha, beta, shape=len(N_samples))
    y = pm.Bernoulli('y', p=theta[group_idx], observed=data)

    trace_h = pm.sample(2000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [theta, beta, alpha]
Sampling 4 chains, 324 divergences: 100%|????????????????????????????????????| 10000/10000 [00:16<00:00, 614.66draws/s]
There were 73 divergences after tuning. Increase `target_accept` or reparameterize.
There were 89 divergences after tuning. Increase `target_accept` or reparameterize.
There were 101 divergences after tuning. Increase `target_accept` or reparameterize.
The acceptance probability does not match the target. It is 0.6792393089435075, but should be close to 0.8. Try to increase the number of tuning steps.
There were 61 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 10% for some parameters.

バーンインとtraceplot

バーンインとして最初の200サンプルを削除します。

ここでサンプルという言葉を使いましたが,シミュレーション1周で90のデータを出力しますので,200周のデータを削除するという意味です。

出力された図の左側は事前分布の出力パラメータ alpha, beta, thetaのカーネル密度推定(KDE)です。thetaは明確に3つのグループに分かれているのが読み取れます。

右側の図はサンプル値であり,収束の判定に使います。全体の動きとして上や下へ向かうのは収束していません。またミクロに見て,正規分布のように見えれば正常です。

In [7]:
chain_h = trace_h[200:]
pm.traceplot(chain_h)

# plt.savefig('img314.png', dpi=300, figsize=(5.5, 5.5))

plt.figure()
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
C:\Users\yamak\Anaconda3\lib\site-packages\arviz\plots\backends\matplotlib\distplot.py:38: UserWarning: Argument backend_kwargs has not effect in matplotlib.plot_distSupplied value won't be used
  "Argument backend_kwargs has not effect in matplotlib.plot_dist"
Out[7]:
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>

summary

事前分布の出力パラメータ alpha, beta, theta のsummaryです。theta については3つに分けられています。

In [8]:
pm.summary(chain_h)
Out[8]:
mean sd hpd_3% hpd_97% mcse_mean mcse_sd ess_mean ess_sd ess_bulk ess_tail r_hat
alpha 17.42 18.13 1.22 42.36 1.79e+00 1.43e+00 102.0 81.0 372.0 129.0 1.02
beta 17.63 18.49 1.01 43.52 1.83e+00 1.47e+00 102.0 80.0 369.0 120.0 1.02
theta[0] 0.55 0.08 0.41 0.69 1.00e-03 1.00e-03 2648.0 2648.0 2628.0 3456.0 1.00
theta[1] 0.50 0.07 0.36 0.63 1.00e-03 1.00e-03 3056.0 3056.0 3064.0 3882.0 1.00
theta[2] 0.45 0.07 0.30 0.57 1.00e-03 1.00e-03 3425.0 3377.0 3428.0 3757.0 1.00

theta の事前分布を出力する

In [9]:
x = np.linspace(0, 1, 100)
for i in np.random.randint(0, len(chain_h), size=100):
    pdf = stats.beta(chain_h['alpha'][i], chain_h['beta'][i]).pdf(x)
    plt.plot(x, pdf, 'g', alpha=0.05)

dist = stats.beta(chain_h['alpha'].mean(), chain_h['beta'].mean())
pdf = dist.pdf(x)
mode = x[np.argmax(pdf)]
mean = dist.moment(1)
plt.plot(x, pdf, label='mode = {:.2f}\nmean = {:.2f}'.format(mode, mean))

plt.legend(fontsize=14)
plt.xlabel(r'$\theta_{prior}$', fontsize=16)
# plt.savefig('img315.png', dpi=300, figsize=(5.5, 5.5))

plt.figure()
Out[9]:
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>

Gelman-Rubinテスト

数量的に収束をチェックします。1.1未満なら良好とします。

In [10]:
pm.gelman_rubin(chain_h)
C:\Users\yamak\Anaconda3\lib\site-packages\pymc3\stats\__init__.py:43: UserWarning: gelman_rubin has been deprecated. In the future, use rhat instead.
  warnings.warn("gelman_rubin has been deprecated. In the future, use rhat instead.")
Out[10]:
<xarray.Dataset>
Dimensions:      (theta_dim_0: 3)
Coordinates:
  * theta_dim_0  (theta_dim_0) int32 0 1 2
Data variables:
    alpha        float64 1.016
    beta         float64 1.015
    theta        (theta_dim_0) float64 1.001 1.001 1.002

theta のHPD

In [11]:
pm.forestplot(chain_h, varnames=['theta'])
C:\Users\yamak\Anaconda3\lib\site-packages\pymc3\plots\__init__.py:21: UserWarning: Keyword argument `varnames` renamed to `var_names`, and will be removed in pymc3 3.8
  warnings.warn('Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8'.format(old=old, new=new))
Out[11]:
array([<matplotlib.axes._subplots.AxesSubplot object at 0x000001D4DF676400>],
      dtype=object)

自己相関のチェック

In [12]:
pm.autocorrplot(chain_h)
plt.figure()
Out[12]:
<Figure size 432x288 with 0 Axes>
<Figure size 432x288 with 0 Axes>

alpha, beta, theta の事後分布

In [13]:
pm.plot_posterior(chain_h, kind='kde')
Out[13]:
array([<matplotlib.axes._subplots.AxesSubplot object at 0x000001D4E30A0C88>,
       <matplotlib.axes._subplots.AxesSubplot object at 0x000001D4E307C668>,
       <matplotlib.axes._subplots.AxesSubplot object at 0x000001D4E306B7B8>,
       <matplotlib.axes._subplots.AxesSubplot object at 0x000001D4E3E9D160>,
       <matplotlib.axes._subplots.AxesSubplot object at 0x000001D4E30F4BA8>],
      dtype=object)
In [ ]:

inserted by FC2 system