こんにちは、リブセンスでデータサイエンティストをしている北原です。今回も以前の記事に続き、分散共分散行列のベイズ推定を扱います。今回は、逆Wishart分布を事前分布として分散共分散行列を推定するときに生じる問題を取り上げます。分散共分散行列の事前分布としては逆Wishart分布が有名ですが、扱う問題によっては事後分布にバイアスが生じ不適切な推定結果となります。どのようなときに問題が生じるかを説明し、実際に推定して確認します。バイアス問題の回避方法については別記事にて紹介します。利用言語はRとStanです。相関係数推定のStanコードについては以下の過去記事もご参照ください。
逆Wishart分布
まず、逆Wishart分布について簡単に説明します。
逆Wishart分布は、Wishart分布に従う行列の逆行列が従う分布です。行列のサイズを$K$としたとき、対称正定値行列$\mathbf {\Sigma}$の逆Wishart分布は、自由度パラメータ$\nu \in (K−1, \infty)$と対称正定値行列のパラメータ$S \in \mathbb{R}^{K \times K}$を使って
となります。多変量だと分布の形状を正確に把握するのは難しいのですが、逆Wishart分布を1変量にしたものが逆カイ二乗分布であることから、裾が長く0に近づくと急激に0に近づくことがイメージできます。
逆Wishart分布は、分散共分散行列の共役事前分布として有名で、よく利用されています。分散共分散行列の共役事前分布であるため、ギブスサンプリングとの相性がよいです。パラメータを$\nu = K + 1$、$S$を対角行列とすると相関の周辺分布が一様分布になります。そのため、逆Wishart分布を事前分布として使うときは$\nu = K + 1$とし$S$を単位行列とすることが多いです。
逆Wishart事前分布のバイアス
分散共分散行列の事前分布に逆Wishart分布を使うことによる問題はいくつもの指摘がなされてきました。Alvarez et al. (2014)では、これらの問題を次のようにまとめています。
- 分散のモデリングの柔軟性が低く、事前知識を組み込みにくい
- 事後分布にバイアスが生じる
- 相関と分散に依存性が生じる
モデリングの柔軟性が低かったり、相関と分散に依存性が生じたりするのは、単一の自由度パラメータで分布の特性を制御しているためです。事後分布にバイアスが生じるのは0近傍での確率密度が極端に低くなるためです。逆Wishart分布を逆カイ二乗分布を多変量に拡張したものとして考えると、0に近づくにつれて確率密度が急減することがイメージできます。
Alvarez et al. (2014)では、シミュレーション研究で、逆Wishart分布を分散共分散行列の事前分布に使ってはいけないケースを明らかにしています。
- 推定対象の標準偏差が事前分布の平均より相対的に小さいとき、事後分布にバイアスが生じる
- 相関は0に近づくよう過小に推定される
- 標準偏差は過大に推定される
- サンプルサイズが大きくなるとバイアスの影響は小さくなるが残る
つまり、推定対象のばらつきが小さいと予想されるケースでは逆Wishart分布を事前分布として利用するのは適切でないと考えられます。
バイアスの確認
では、実際にStanで相関係数を推定してバイアスが生じるかを確認してみましょう。cmdstanのバージョンは2.31.0です。ここでは相関係数を0.7に固定し、標準偏差の大きさを3パターン(1、0.1、0.01)調べます。サンプルデータ生成で使っているパラメータの詳細はコードをご参照ください。
サンプルデータ生成のRコードとStanコードは以下の通りです。2変量なので、逆Wishart分布のパラメータには、自由度3と単位行列を指定しています。Stanを使った相関係数の推定については以前の記事をご参照ください。
library(tidyverse) library(cmdstanr) library(knitr) library(mvtnorm) library(doParallel) registerDoParallel(detectCores()) options(mc.cores = parallel::detectCores()) # 分散共分散行列を作成する関数 make_cov_matrix <- function(sig, rho) { matrix( c(sig[1]^2, sig[1] * sig[2] * rho, sig[1] * sig[2] * rho, sig[2]^2), nrow = 2) } # サンプルデータ生成 set.seed(1) N <- 100 mu <- c(2, 5) # 平均 rho <- 0.7 # 相関係数 sig_l <- c(1, 1) # 標準偏差(大) sig_m <- 0.1 * sig_l # 標準偏差(中) sig_s <- 0.1 * sig_m # 標準偏差(小)) Sig_l <- make_cov_matrix(sig_l, rho) Sig_m <- make_cov_matrix(sig_m, rho) Sig_s <- make_cov_matrix(sig_s, rho) dat_l <- rmvnorm(n = N, mean = mu, sigma = Sig_l) %>% as_tibble(.name_repair = "minimal") %>% setNames(c("x1", "x2")) dat_m <- rmvnorm(n = N, mean = mu, sigma = Sig_m) %>% as_tibble(.name_repair = "minimal") %>% setNames(c("x1", "x2")) dat_s <- rmvnorm(n = N, mean = mu, sigma = Sig_s) %>% as_tibble(.name_repair = "minimal") %>% setNames(c("x1", "x2")) # Stan mod_inv_wishart <- cmdstan_model("inv_wishart.stan")
data{ int<lower=0> N; int<lower=0> K; array[N] vector[K] x; } parameters { vector[K] mu; array[K] real<lower=0> sig; // 標準偏差 real<lower=-1, upper=1> rho; // 相関係数 } transformed parameters { cov_matrix[K] Sig = [[sig[1]^2, rho * sig[1] * sig[2]], [rho * sig[1] * sig[2], sig[2]^2]]; } model { x ~ multi_normal(mu, Sig); Sig ~ inv_wishart(K + 1, identity_matrix(K)); }
まず、標準偏差1のケースから見てみましょう。
# 分散(大) fit_l <- mod_inv_wishart$sample( data = list(N = nrow(dat_l), K = ncol(dat_l), x = dat_l %>% dplyr::select(x1, x2)) ) fit_l$summary() %>% mutate_at(vars(-variable), ~ round(., digits = 2)) %>% kable() %>% print() fit_m <- mod_inv_wishart$sample( data = list(N = nrow(dat_m), K = ncol(dat_m), x = dat_m %>% dplyr::select(x1, x2)) )
推定結果は以下の通りです。rhoが相関係数、sigが標準偏差です。相関係数はおよそ0.69なので、問題なく推定されていることがわかります。
Table: 標準偏差1
variable | mean | median | sd | mad | q5 | q95 | rhat | ess_bulk | ess_tail |
---|---|---|---|---|---|---|---|---|---|
lp__ | -54.28 | -53.95 | 1.60 | 1.45 | -57.25 | -52.34 | 1 | 1980.33 | 2749.78 |
mu[1] | 2.07 | 2.07 | 0.09 | 0.09 | 1.92 | 2.22 | 1 | 2458.65 | 2255.20 |
mu[2] | 5.03 | 5.03 | 0.09 | 0.09 | 4.88 | 5.19 | 1 | 2483.63 | 2611.69 |
sig[1] | 0.91 | 0.91 | 0.06 | 0.06 | 0.82 | 1.02 | 1 | 2740.19 | 2875.48 |
sig[2] | 0.93 | 0.93 | 0.07 | 0.06 | 0.83 | 1.04 | 1 | 2766.57 | 2775.84 |
rho | 0.69 | 0.69 | 0.05 | 0.05 | 0.60 | 0.77 | 1 | 3066.16 | 2939.89 |
Sig[1,1] | 0.84 | 0.83 | 0.12 | 0.11 | 0.67 | 1.05 | 1 | 2740.17 | 2875.48 |
Sig[2,1] | 0.59 | 0.58 | 0.10 | 0.10 | 0.44 | 0.77 | 1 | 2305.56 | 2359.92 |
Sig[1,2] | 0.59 | 0.58 | 0.10 | 0.10 | 0.44 | 0.77 | 1 | 2305.56 | 2359.92 |
Sig[2,2] | 0.87 | 0.86 | 0.12 | 0.12 | 0.69 | 1.09 | 1 | 2766.58 | 2775.84 |
次に、標準偏差0.1と0.01のケースを見てみましょう。Rコードは標準偏差1のケースとほぼ同じなので省略します。
Table: 標準偏差0.1
variable | mean | median | sd | mad | q5 | q95 | rhat | ess_bulk | ess_tail |
---|---|---|---|---|---|---|---|---|---|
lp__ | 314.62 | 314.94 | 1.57 | 1.44 | 311.65 | 316.55 | 1 | 1999.06 | 2916.11 |
mu[1] | 2.00 | 2.00 | 0.01 | 0.01 | 1.98 | 2.02 | 1 | 3321.52 | 2826.60 |
mu[2] | 5.01 | 5.01 | 0.01 | 0.01 | 4.99 | 5.03 | 1 | 4217.46 | 3483.05 |
sig[1] | 0.14 | 0.14 | 0.01 | 0.01 | 0.13 | 0.16 | 1 | 3675.98 | 2510.00 |
sig[2] | 0.14 | 0.14 | 0.01 | 0.01 | 0.12 | 0.16 | 1 | 3945.63 | 2784.80 |
rho | 0.34 | 0.34 | 0.09 | 0.09 | 0.19 | 0.47 | 1 | 3772.17 | 2781.13 |
Sig[1,1] | 0.02 | 0.02 | 0.00 | 0.00 | 0.02 | 0.03 | 1 | 3675.98 | 2510.00 |
Sig[2,1] | 0.01 | 0.01 | 0.00 | 0.00 | 0.00 | 0.01 | 1 | 3103.56 | 2274.59 |
Sig[1,2] | 0.01 | 0.01 | 0.00 | 0.00 | 0.00 | 0.01 | 1 | 3103.56 | 2274.59 |
Sig[2,2] | 0.02 | 0.02 | 0.00 | 0.00 | 0.02 | 0.02 | 1 | 3945.62 | 2784.80 |
Table: 標準偏差0.01
variable | mean | median | sd | mad | q5 | q95 | rhat | ess_bulk | ess_tail |
---|---|---|---|---|---|---|---|---|---|
lp__ | 379.40 | 379.73 | 1.58 | 1.39 | 376.41 | 381.30 | 1 | 2114.11 | 2354.87 |
mu[1] | 2.00 | 2.00 | 0.01 | 0.01 | 1.98 | 2.02 | 1 | 5255.66 | 3140.49 |
mu[2] | 5.00 | 5.00 | 0.01 | 0.01 | 4.98 | 5.02 | 1 | 5202.76 | 3168.72 |
sig[1] | 0.10 | 0.10 | 0.01 | 0.01 | 0.09 | 0.11 | 1 | 4796.71 | 2767.79 |
sig[2] | 0.10 | 0.10 | 0.01 | 0.01 | 0.09 | 0.11 | 1 | 4791.14 | 3021.03 |
rho | 0.01 | 0.01 | 0.10 | 0.10 | -0.15 | 0.16 | 1 | 4468.67 | 2851.13 |
Sig[1,1] | 0.01 | 0.01 | 0.00 | 0.00 | 0.01 | 0.01 | 1 | 4796.68 | 2767.79 |
Sig[2,1] | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 1 | 4098.84 | 2809.06 |
Sig[1,2] | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 1 | 4098.84 | 2809.06 |
Sig[2,2] | 0.01 | 0.01 | 0.00 | 0.00 | 0.01 | 0.01 | 1 | 4791.10 | 3021.03 |
サンプルデータの相関係数は0.7に固定されていますが、サンプルデータの標準偏差が小さくなるにつれて、推定結果の相関係数rhoは0.7より小さくなり、標準偏差sigはサンプルデータの標準偏差より大きく推定されていることが確認できます。つまり、推定対象の標準偏差が小さいと、$\hat{R}$や有効サンプルサイズは問題なさそうに見えても、誤った推定結果になっていることがわかります。推定対象の標準偏差が小さい可能性がある場合は、逆Wishart分布を事前分布として利用するのは避けた方がよいと考えられます。
まとめ
今回は、推定対象の分散が小さいときに、逆Wishart分布を事前分布として分散共分散行列を推定するときの問題について紹介しました。推定対象の分散が小さくなるにつれて、推定結果にバイアスが生じ、相関は無相関に近づき標準偏差は過大に推定されることを確認しました。よく使われる事前分布であっても、推定対象によっては必ずしも適切ではないことがあるので注意が必要です。分散共分散行列の推定に推奨される事前分布については別記事にて紹介します。
参考
I. Alvarez, J. Niemi, and M. Simpson, “Bayesian inference for a covariance matrix.” arXiv, Jul. 08, 2016.