Obey Your MATHEMATICS.

機械学習関連の純粋数学や実験など

アヒル本のモデルをいくつかPyMCで実装しました. 詰まったところ.

今話題のアヒル本

StanとRでベイズ統計モデリング (Wonderful R)

StanとRでベイズ統計モデリング (Wonderful R)


の後半の方にあるモデルをいくつかPyMC3で実装しました:

github.com


特に一番重要であろうChapter8は全部実装してあります。

間違いや、こうしたらどうですか?みたいなコメントあったらTwitterまで御連絡ください。


この本は”統計モデリングとは”から始まり懇切丁寧にベイズ統計の実践方法を解説してあり、とてもためになります。

データ分析プロセス↓に並んで、Rユーザー以外にもおすすめ出来ます。*1

データ分析プロセス (シリーズ Useful R 2)

データ分析プロセス (シリーズ Useful R 2)



と、言う報告で終わらせたかったのですがPyMC3で実装していて詰まった所をメモがてら残しておこうと思います。


§1. クラス値ベクトルの代入

ahiru_book_pymc/model8-2_pymc.py at master · mathetake/ahiru_book_pymc · GitHub
model 8-2 について考えます。
ここで考えるデータの行は会社員一人一人のサンプルに対応していて、
各サンプルは変数として年齢-給与-会社 と言う変数を持っています。

そしてここでは、
会社(Class)ごとに、給与(Y)年齢(X)を説明変数として線形回帰するものです。

コードを見てみましょう:

df = pd.read_csv("data-salary-2.txt")
X_data = df.values[:,0]
Y_data = df.values[:,1]
Class_data  = df.values[:,2]-1
n_Class = len(df["KID"].unique())


basic_model = Model()

with basic_model:
    a = Normal('a', mu=0, sd=10, shape=n_Class)
    b = Normal('b', mu=0, sd=10, shape=n_Class)
    epsilon = HalfNormal('sigma', sd=1)

    #likelihood 
    mu = a[Class_data] + b[Class_data]*X_data
    Y_obs = Normal('Y_obs', mu=mu, sd=epsilon, observed=Y_data)
    trace = sample(2000)
    summary(trace)

さて、会社毎(Class)に切片と傾きを求める必要があります。つまりn_Class個の切片と傾きがいるわけです。

ですのでまず a(切片) とb(傾き) に対する事前分布に shape=n_Classを渡しています。
ここまでは良いでしょう。

そのあと尤度を計算する際に

a[Class_data]
b[Class_data]

と言う記述があります。Class_data大きさがn_data のnumpyベクトルで、
会社のクラスの値(0~n_Class-1)を取っています。*2

このベクトルを a や b に代入すると、shape = n_data のvariableになり、各成分がクラスに対応した確率変数になっています。

うーーーん。難しい。いや、理解はしてるんですが、なんとなく気持ち悪い。

開発者のブログポストのコメント欄でも、counterintuitiveであると本人が述べています:

The Best Of Both Worlds: Hierarchical Linear Regression in PyMC3





model11-8はもっと複雑で、上述の操作の行列に代入バージョンをやってます。。。。


§2. 変化点検出の収束が異常に遅い


model12-7はコーシー分布を使って変化点検出をするんですが、収束が異常に遅い。

1/4のデータ数でなんとかサンプリングしましたが、収束してるのかどうか怪しい…。

パラメータのサマリ↓
f:id:mathetake:20170122000323p:plain

実際のデータと予測のプロット↓
f:id:mathetake:20170122000331p:plain

どうにかしたいけど………。

しっかりと逆関数法を使って再パラメータ化もしてますが……。


追記:

著者の@berobero11さんからリプをもらいました:

たぶん原因はこれです. 気が向いたらまた実験します.

*1:僕自身Rユーザーではありません。

*2:0~n_Class-1 に直すためにデータを取り出した段階で -1 しています。1~n_Classだとたしかエラーがでました。