LIVESENSE Data Analytics Blog

リブセンスのデータ分析、機械学習、分析基盤に関する取り組みをご紹介するブログです。

LKJ相関分布を使った分散共分散行列のベイズ推定

 こんにちは、リブセンスでデータサイエンティストをしている北原です。今回も分散共分散行列のベイズ推定を扱います。以前の記事で分散共分散行列の事前分布に逆Wishart分布を使うと、推定対象の分散が小さいときに推定バイアスが生じることを紹介しました。では、どのような事前分布を使ったらよいかというのが今回の内容です。記事タイトルからも推測できるように、分散共分散行列を分散の対角行列と相関行列に分解して、相関行列の事前分布にLKJ(Lewandowski-Kurowicka-Joe)相関分布を使うとよいです。これはStanで推奨されている方法でもあります。利用言語はRとStanです。なお、本記事は以下の過去記事をふまえた内容となっています。

analytics.livesense.co.jp

analytics.livesense.co.jp

LKJ相関分布

 LKJ相関分布(Lewandowskiet al. (2008))は、$\mathbf{R}$を相関行列、$c_d$を正規化定数とするとき $$ \begin{align} P(\mathbf{R}) &= c_d\det \left(\mathbf{R} \right)^{\eta - 1} \end{align} $$ となります。なお、相関行列は、対角成分が1で、$i$行$j$列の非対角成分が相関係数$\rho_{i,j}$となる正定値行列です。

 LKJ相関分布は、分散共分散行列を分散の対角行列と相関行列に分解して、相関行列に対する事前分布として使います。適当に乱数を発生させるだけでは相関行列にはならないし、相関行列の条件を満たしていてもサンプリングに時間がかかるようだと実用的に使えません。LKJ相関分布は効率的なランダム相関行列を生成するために提案されたものなので効率的な計算が可能です(Lewandowskiet al. (2008))。

 相関の強さはパラメータ$\eta$を使って制御します。

  • $\eta = 1$のとき、一様分布(無情報)
  • $\eta > 1$のとき、値が大きいほど0周辺の確率が高くなる(弱い相関)
  • $\eta < 1$のとき、値が小さいほど$\pm1$ 近辺の確率が高くなる(強い相関)

簡単な例として相関係数$\rho$ の2変量相関分布を考えると、理解しやすいかもしれません。行列式は$(1-\rho^2)^{\eta - 1}$なので、$\eta=1$なら定数、$\eta > 1$なら単峰で0に近づくほど高くなり、$\eta < 1$なら0が一番低く$\pm1$ に近づくほど高くなることがわかります。

LKJ相関分布を使うとよい理由

 分散共分散行列を分散の対角行列と相関行列とに分解し、相関行列の事前分布にLKJ相関分布を使う利点は次のとおりです。

  • 事後分布に(逆Wishart分布を使ったときのような)バイアスが生じにくい
  • 事前知識を組み込みやすい
  • 計算が比較的高速

 逆Wishart分布を使うと標準偏差が小さいとバイアスが生じるので注意が必要でしたが、LKJ相関分布を使う場合はその問題が緩和されるので比較的気軽に利用できます。逆Wishart分布のように分散が小さい確率が極端に低くなるわけではないため、逆Wishart分布を事前分布としたときのようなバイアスが生じにくいです。また、相関と分散を個別にモデル化しているため、相関と分散の依存性も生じにくいです。

 分散と相関それぞれに事前分布を設定することで、事前知識も組み込みやすくなっています。分散だけでなく、相関の強さも分散とは個別にパラメータで制御できます。そのため、相関の強弱をモデル化しやすくなっています。

 多変量になると計算時間がボトルネックになってきますが、LKJ相関分布は比較的高速なサンプリングが可能です。特に、Stanではコレスキー分解と合わせて使うことが推奨されています。

LKJ相関分布を利用した相関係数のベイズ推定

 それでは、実際にLKJ相関分布を使って相関係数を推定してみましょう。LKJ相関分布をそのまま利用した推定と、以前の記事で紹介したコレスキー分解を使った推定の二つを行います。コレスキー分解を使った推定では、逆Wishart分布を使ったときではバイアスが生じていたデータに対して、LKJ相関分布を利用すると適切な推定ができることを確認します。ここでは$\eta=1$として一様分布を指定します。

 以前の記事の再掲になりますがサンプルデータを作成します。

library(tidyverse)
library(cmdstanr)
library(knitr)

library(mvtnorm)

library(doParallel)
registerDoParallel(detectCores())
options(mc.cores = parallel::detectCores())

# 分散共分散行列を作成する関数
make_cov_matrix <- function(sig, rho) {
  matrix(
    c(sig[1]^2,              sig[1] * sig[2] * rho,
      sig[1] * sig[2] * rho, sig[2]^2),
    nrow = 2)       
}

# サンプルデータ生成
set.seed(1)
N <- 100
mu <- c(2, 5)           # 平均
rho <- 0.7              # 相関係数
sig_l <- c(1, 1)        # 標準偏差(大)
sig_m <- 0.1 * sig_l    # 標準偏差(中)
sig_s <- 0.1 * sig_m    # 標準偏差(小))
Sig_l <- make_cov_matrix(sig_l, rho)
Sig_m <- make_cov_matrix(sig_m, rho)
Sig_s <- make_cov_matrix(sig_s, rho)

dat_l <- 
  rmvnorm(n = N, mean = mu, sigma = Sig_l) %>% 
  as_tibble(.name_repair = "minimal") %>% 
  setNames(c("x1", "x2"))
dat_m <- 
  rmvnorm(n = N, mean = mu, sigma = Sig_m) %>% 
  as_tibble(.name_repair = "minimal") %>% 
  setNames(c("x1", "x2"))
dat_s <- 
  rmvnorm(n = N, mean = mu, sigma = Sig_s) %>% 
  as_tibble(.name_repair = "minimal") %>% 
  setNames(c("x1", "x2"))

 まず、LKJ相関分布をそのまま利用して、標準偏差が比較的大きいデータの相関係数を推定しましょう。RとStanのコードは次の通りです。

# コレスキー分解を使わずに相関係数を推定
mod_lkj <- cmdstan_model("lkj.stan")
fit <- mod_lkj$sample(
  data = list(N = nrow(dat_s), 
              K = ncol(dat_s),
              x = dat_s %>% dplyr::select(x1, x2))
)
fit$summary() %>% 
  mutate_at(vars(-variable), ~ round(., digits = 2)) %>% 
  kable() %>% 
  print()
# lkj.stan
data{
  int<lower=0> N;
  int<lower=0> K;
  array[N] vector[K] x;
}
parameters {
  vector[K] mu;
  corr_matrix[K] R;
  vector<lower=0>[K] sigma;
}
model {
  x ~ multi_normal(mu, quad_form_diag(R, sigma));
  sigma ~ cauchy(0, 5); # 分散
  R ~ lkj_corr(1);      # LKJ相関分布
}

 推定結果は次の通りです。当然ながら問題なく推定できていることがわかります。

variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ -54.38 -54.05 1.63 1.44 -57.51 -52.42 1 1759.32 2390.98
mu[1] 2.07 2.07 0.09 0.09 1.91 2.22 1 2191.51 2877.01
mu[2] 5.03 5.03 0.10 0.10 4.87 5.18 1 2201.72 2704.12
R[1,1] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA
R[2,1] 0.69 0.70 0.05 0.05 0.60 0.78 1 2661.22 2989.59
R[1,2] 0.69 0.70 0.05 0.05 0.60 0.78 1 2661.22 2989.59
R[2,2] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA
sigma[1] 0.94 0.93 0.07 0.07 0.83 1.05 1 2870.02 2882.43
sigma[2] 0.96 0.95 0.07 0.07 0.85 1.08 1 2836.73 2682.16

 次に、コレスキー分解を利用して推定します。Stanのコードは次の通りです。Stanの呼び出しコードはコレスキー分解を使わないときとほぼ同じなので省略します。Stanでは多変量正規分布もLKJ相関分布もコレスキー因子に対応したものが用意されているので、それらを使います。

data{
  int<lower=0> N;
  int<lower=0> K;
  array[N] vector[K] x;
}
parameters {
  vector[K] mu;
  cholesky_factor_corr[K] L;                 # コレスキー因子
  vector<lower=0>[K] sig;
}
transformed parameters {
  corr_matrix[K] R;                          # 相関
  R = multiply_lower_tri_self_transpose(L);  # R = L * L'
}
model {
  x ~ multi_normal_cholesky(mu, diag_matrix(sig) * L);
  sig ~ cauchy(0, 5);                        # 分散
  L ~ lkj_corr_cholesky(1);                  # コレスキー因子のLKJ相関分布
}

 推定結果は次の通りです。サンプルデータの標準偏差が、大、中、小の順に示します。

Table: 標準偏差1

variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ -54.30 -54.04 1.55 1.40 -57.27 -52.37 1 2037.60 2636.02
mu[1] 2.07 2.07 0.09 0.09 1.91 2.22 1 2674.84 2458.21
mu[2] 5.03 5.03 0.10 0.09 4.87 5.19 1 2643.38 2674.78
L[1,1] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA
L[2,1] 0.69 0.70 0.05 0.05 0.60 0.77 1 2799.36 2753.96
L[1,2] 0.00 0.00 0.00 0.00 0.00 0.00 NA NA NA
L[2,2] 0.72 0.72 0.05 0.05 0.63 0.80 1 2799.38 2753.96
sig[1] 0.93 0.93 0.07 0.07 0.83 1.05 1 2691.41 2821.48
sig[2] 0.95 0.95 0.07 0.07 0.85 1.07 1 2854.97 2666.02
R[1,1] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA
R[2,1] 0.69 0.70 0.05 0.05 0.60 0.77 1 2799.36 2753.96
R[1,2] 0.69 0.70 0.05 0.05 0.60 0.77 1 2799.36 2753.96
R[2,2] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA

Table: 標準偏差0.1

variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ 385.49 385.80 1.61 1.44 382.38 387.46 1 1693.41 2359.71
mu[1] 2.00 2.00 0.01 0.01 1.98 2.02 1 3225.03 2816.02
mu[2] 5.01 5.01 0.01 0.01 4.99 5.03 1 3105.07 2523.10
L[1,1] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA
L[2,1] 0.68 0.68 0.06 0.06 0.58 0.76 1 3266.98 3002.45
L[1,2] 0.00 0.00 0.00 0.00 0.00 0.00 NA NA NA
L[2,2] 0.73 0.73 0.05 0.05 0.65 0.82 1 3266.93 3002.45
sig[1] 0.10 0.10 0.01 0.01 0.09 0.12 1 3171.79 2624.12
sig[2] 0.10 0.10 0.01 0.01 0.09 0.11 1 2941.47 2524.74
R[1,1] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA
R[2,1] 0.68 0.68 0.06 0.06 0.58 0.76 1 3266.98 3002.45
R[1,2] 0.68 0.68 0.06 0.06 0.58 0.76 1 3266.98 3002.45
R[2,2] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA

Table: 標準偏差0.01

variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ 830.49 830.86 1.66 1.42 827.30 832.47 1 1465.93 1766.83
mu[1] 2.00 2.00 0.00 0.00 2.00 2.00 1 5183.15 3146.45
mu[2] 5.00 5.00 0.00 0.00 5.00 5.00 1 4924.74 3179.99
L[1,1] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA
L[2,1] 0.64 0.65 0.06 0.06 0.54 0.73 1 1637.86 2139.46
L[1,2] 0.00 0.00 0.00 0.00 0.00 0.00 NA NA NA
L[2,2] 0.76 0.76 0.05 0.05 0.68 0.84 1 1637.86 2139.46
sig[1] 0.01 0.01 0.00 0.00 0.01 0.01 1 1913.54 2153.53
sig[2] 0.01 0.01 0.00 0.00 0.01 0.01 1 1640.22 1958.51
R[1,1] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA
R[2,1] 0.64 0.65 0.06 0.06 0.54 0.73 1 1637.86 2139.46
R[1,2] 0.64 0.65 0.06 0.06 0.54 0.73 1 1637.86 2139.46
R[2,2] 1.00 1.00 0.00 0.00 1.00 1.00 NA NA NA

 以前の記事と同じ条件でデータを生成しているので、逆Wishart分布を使ったときの推定結果と比較したものを次に示します。サンプルデータの相関係数は0.7なので、推定結果が0.7に近ければ適切な推定ができていると判断できます。LKJ相関分布を利用しても、標準偏差が小さくなるとわずかに相関係数が小さく推定されていることがわかります。しかし、逆Wishart分布を利用したときほどではないことが確認できます。なお、標準偏差が小さくなることで生じるわずかなバイアスについては、逆Wishart分布を改善した手法でも生じることが知られています(Alvarez et al. (2014))。

Table: 逆Wishart分布とLKJ相関分布の相関係数推定の比較

サンプルデータの標準偏差 逆Wishart LKJ相関
1 0.69 0.69
0.1 0.34 0.68
0.01 0.01 0.64

まとめ

 今回は、LKJ相関分布を利用した分散共分散行列の推定について紹介しました。分散共分散行列を分散の対角行列と相関行列に分解し、相関行列の事前分布にLKJ相関分布を使うことで、逆Wishart分布を事前分布として利用するときより、バイアスが生じにくく、事前知識も組み込みやすくなることを示しました。結局のところ、Stanで推奨されている方法を使うのがよいということではありますが、LKJ相関分布を使うメリットを把握することでよりよいモデリングが可能になるのではないかと思います。

参考

D. Lewandowski, D. Kurowicka, and H. Joe, “Generating random correlation matrices based on vines and extended onion method,” Journal of Multivariate Analysis, vol. 100, no. 9, pp. 1989–2001, Oct. 2009.

I. Alvarez, J. Niemi, and M. Simpson, “Bayesian inference for a covariance matrix.” arXiv, Jul. 08, 2016.