LIVESENSE Data Analytics Blog

リブセンスのデータ分析、機械学習、分析基盤に関する取り組みをご紹介するブログです。

Stanによるレコメンデーション用Factorization Machinesの実装

 こんにちは、リブセンスでデータサイエンティストをしている北原です。今回はStanを使ったレコメンデーション用FM(Factorization Machines)を扱います。

 FMはシンプルなモデルなのでStanで簡単に実装することができます。しかし、レコメンデーションで使う場合はスパースデータに対応したものにしないと無駄な計算が多く計算に非常に時間がかかってしまったりメモリを大量消費して計算できなくなったりします。そこで、今回はスパースデータに対応したFMのStan実装について紹介します。

なぜレコメンデーションモデルの実装にStanを使うのか

 まず、Stanでレコメンデーション用FMを実装できると何がうれしいかについて考えましょう。本ブログではFMの実装にはJuliaを使ってきたので、Juliaなどの言語でフルスクラッチ実装するときと比較して説明します。

 実務で機械学習を使う場合、課題に合いそうな既存アルゴリズムやモデルを探してきて適用する方法と、課題に合うように既存アルゴリズムやモデルをカスタマイズする方法があります。どちらが優れているというものではなく、状況に合わせて使い分けます。比較的シンプルな課題に素早く対応する場合は前者、複雑かつ重要な課題に時間をかけて対応する場合は後者を採用します。実際には、初期段階では前者を採用することが多く、それだけでKPIを大幅に向上させられるのが理想です。しかし、弊社の場合そのようなケースは稀で、課題構造が明らかになるにつれて後者に移行していくケースが多いです。このことはレコメンデーションでも同様です。そのため、課題に合わせたカスタマイズを効率化できるとありがたいわけです。

 課題に合わせたカスタマイズでは、試行錯誤的にモデル構造を変更します。そうすると、モデル構造は徐々に複雑なものになっていくことが多くなります。このようなケースでも、JuliaやPythonでモデルのカスタマイズはできるし、実際に行ってきました(例えば、こちら)。しかし、モデル構造が複雑になってくると実装の手間は増えます。そして、複雑になるほど、計算式導出時の計算ミスや実装ミスも増えやすくなり開発効率が低下します。

 Stanはこのような課題に合わせたカスタマイズをするときに有用です。Stanのようにモデル構造を記述するだけでよい言語であれば、モデル構造が複雑になってもフルスクラッチで実装するときほど実装の手間は増えません。そのため、Stanを使うことで様々なモデルの検証を効率的に行うことができます。なお、今回の目的の範囲で言えば、Stan以外のTensorFlow ProbabilityやPyroを使うのでもよいと思います。どのツールも一長一短ありますが、本ブログでStanを使っているのは筆者がStanに慣れているからというのが一番の理由です。少ない手間でサクサク使うためには、慣れているものを使うのがよいですからね。

 レコメンデーション用FMを使う場合スパースデータを扱う必要があるので、Stanコードもスパースデータに対応させる必要があります。レコメンデーション用FMでは、ユーザーIDやアイテムIDごとに特徴量を割り当てたデータを入力として学習するため、非常にスパースなデータを効率的に処理することがもとめられます。現時点(2021年末)ではStanによるスパースデータのサポート状況はあまり充実しておらず、使いやすいデータ構造や関数は少ないです。しかし、CSR(Compressed Sparse Row)形式などのスパースデータ用データ構造に変換されたデータを入力とし、少しコードを工夫するとスパースデータであっても効率的な処理が可能です。

 ただし、Stan実装だけですべての実運用ケースをカバーできるわけではありません。データ量が多いときにはスパースデータに対応させても計算量やメモリ使用量が問題になりやすいので、Stanで扱うのは難しくなります。そのため、Stan実装は、データ量が少ない新規サービスの初期段階や、一部のデータを使ったモデル改善での利用に限定されます。それでも、弊社では現在新規事業に力を入れていて、まだデータ量がそれほど多くないケースも多いので活用場面は割とあります。

シンプルなFMのStan実装

 まずデータがスパースなことを想定しないシンプルなFMをStanで実装してみましょう。

 Rのライブラリcmdstanrを利用します。stanのバージョンは2.28.2です。古いバージョンを使っている場合は、本記事で紹介するコードを実行するとエラーになることがあるのでご注意ください。

 ここではFMのモデルはモデルパラメータの事前分布に正規分布を仮定したシンプルなものを使います。特徴量数を$n$、因子数を$k$とし、評価値$y$、評価と対応するコンテキストデータ${\bf x} \in{\mathbb R}^{n}$、モデルパラメータ$w_0 \in {\mathbb R}$、${\bf w} \in {\mathbb R}^n$、${\bf V} \in {\mathbb R}^{n \times k}$を使って、以下のようなモデルを考えます。

$$ \begin{eqnarray} \hat{y}({\bf x}) &=& w_0 + \sum^n_{j=1}{w_j x_j} + \frac{1}{2} \sum_{f=1}^{k}\left(\left(\sum_{j=1}^{n} v_{j, f} x_{j}\right)^{2}-\sum_{j=1}^{n} v_{j, f}^{2} x_{j}^{2}\right) \label{eq:fm} \\ y &\sim& {\mathcal N}(\hat{y} , \sigma) \label{eq:y} \\ w_{i} &\sim& {\mathcal N}(\mu^{(w)}, \sigma^{(w)}) \label{eq:w_i} \\ v_{i, f} &\sim& {\mathcal N}(\mu^{(v)}, \sigma^{(v)}) \label{eq:v_ij} \\ \end{eqnarray} $$

 利用するライブラリを読み込み、FMのモデル式を参考に学習するサンプルデータを次のように作成します。データをスパースにするためにカラムごとに9割のデータに値0を入れています。

library(tidyverse)
library(cmdstanr)
library(randomForest)
library(SparseM)

# Stanで並列処理するための設定
library(doParallel)
registerDoParallel(detectCores())
options(mc.cores = parallel::detectCores())

# 学習サンプルデータ作成 -------------------------------------------------------------
set.seed(123)
N <- 50 # サンプル数
M <- 50 # 特徴量数
K <- 5  # 因子数

# モデルパラメータ
w0 <- 1
w <- rnorm(M, mean = 0, sd = 1)
V <- matrix(rnorm(M * K, mean = 0, 1), nrow = M)

# データ
X <- matrix(rnorm(N * M, 0, 1), nrow = N)
X_test <- matrix(rnorm(N * M, mean = 1, sd = 1), nrow = N)

# スパースにするため0を入れる
for (i in 1:M) {
  X[sample(1:N, N * 0.9), i] <- 0
}
y_ <- w0 + c(X %*% w) + 0.5 * rowSums((X %*% V)^2 - X^2 %*% V^2)
y <- y_ %>% rnorm(n = length(.), mean = ., sd = 0.1)

for (i in 1:M) {
  X_test[sample(1:N, N*0.9), i] <- 0
}
y_test_ <- w0 + c(X_test %*% w) + 0.5 * rowSums((X_test %*% V)^2 - X_test^2 %*% V^2)
y_test <- y_test_ %>% rnorm(n = length(.), mean = ., sd = 0.1)

 FMとの結果の比較用に、RandomForestで学習させてみましょう。ここでは評価基準にRMSE(Root Mean Squared Error)を使います。

# ランダムフォレスト ---------------------------------------------------------------
rf <- randomForest(y ~ ., data.frame(X))
y_est_rf <- predict(rf, data.frame(X))
y_pred_rf <- predict(rf, data.frame(X_test))
rmse_rf <- sqrt(mean((y_est_rf - y_)^2))
rmse_test_rf <- sqrt(mean((y_pred_rf - y_test_)^2))
print(rmse_rf)
print(rmse_test_rf)
> print(rmse_rf)
[1] 3.805909
> print(rmse_test_rf)
[1] 11.067

 ではFMの実装に移りましょう。式($\ref{eq:fm}$)〜($\ref{eq:v_ij}$)をそのまま表現すればよいので、Stanコードは次のように書けます。なお、rep_vector()関数は各行の和を計算するためのものです。

// 通常のFM
data {
  int<lower=0> N;
  int<lower=0> M;
  int<lower=0> K;
  matrix[N, M] X;
  vector[N] y;

  int<lower=0> N_test;
  matrix[N_test, M] X_test;
}
parameters {
  real w0;
  vector[M] w;
  matrix[M, K] V;
  real<lower=0> sig;
  real mu_w;
  real<lower=0> sig_w;
  real mu_v;
  real<lower=0> sig_v;
}
model {
  y ~ normal(w0 + X * w +  0.5 * ((X * V).^2 - (X.^2) * (V.^2)) * rep_vector(1.0, K), sig);
  
  w0 ~ normal(mu_w, sig_w);
  w ~ normal(mu_w, sig_w);
  to_vector(V) ~ normal(mu_v, sig_v);
}
generated quantities {
  vector[N] est;
  vector[N_test] pred;
  for(i in 1:N){
    est[i] = normal_rng(w0 + X[i,] * w + 0.5 * sum((X[i,] * V).^2 - (X[i,].^2) * (V.^2)), sig);
  }
  for (i in 1:N_test) {
    pred[i] = normal_rng(w0 + X_test[i,] * w + 0.5 * sum((X_test[i,] * V).^2 - (X_test[i,].^2) * (V.^2)), sig);
  }
}

 Stanの呼び出しRコードは以下のとおりです。

# 通常のFM -------------------------------------------------------------------
model_fm_dense <- cmdstan_model("fm_dense.stan")
fit_fm_dense <- model_fm_dense$sample(
  data = list(
    N = nrow(X),
    M = ncol(X),
    K = K,
    X = X,
    y = y,
    N_test = nrow(X_test),
    X_test = X_test
  ))
s <- fit_fm_dense$summary()
rmse_fm <-
  s %>% 
  filter(variable %>% str_detect("est")) %>% 
  mutate(y = y_) %>% 
  mutate(diff = mean - y) %>% 
  summarise(rmse_fm = sqrt(mean(diff^2)))
rmse_test_fm <-
  s %>% 
  filter(variable %>% str_detect("pred")) %>% 
  mutate(y = y_test_) %>% 
  mutate(diff = mean - y) %>% 
  summarise(rmse_test_fm = sqrt(mean(diff^2)))
print(rmse_fm)
print(rmse_test_fm)
> print(rmse_fm)
# A tibble: 1 × 1
  rmse_fm
    <dbl>
1   0.172
> print(rmse_test_fm)
# A tibble: 1 × 1
  rmse_test_fm
         <dbl>
1         10.5

 データ生成に使ったのと同じFMのモデル式が使われているためか、ランダムフォレストよりよい精度が得られていることがわかります。推定、予測ともうまくいっているようです。

スパースデータ対応FMのStan実装

 それではスパースデータに対応したFMを実装しましょう。学習データはシンプルなFMと同じものを使います。

 Stanコードは以下のようになります。

data {
  int N;
  int M;
  int K;
  int N_ra;
  vector[N_ra] RA;                    // 値
  int<lower=1, upper=N_ra+1> IA[N+1]; // 行範囲を示すための列インデックスへのポインタ
  int<lower=1, upper=M> JA[N_ra];     // 列インデックス
  vector[N] y;
  
  int N_test;
  int N_ra_test;
  vector[N_ra_test] RA_test;
  int<lower=1, upper=N_ra_test+1> IA_test[N_test+1];
  int<lower=1, upper=M> JA_test[N_ra_test];
}

parameters {
  real w0;
  vector[M] w;
  matrix[M, K] V;
  real<lower=0> sig;
  real mu_w;
  real<lower=0> sig_w;
  real mu_v;
  real<lower=0> sig_v;
}
model {
  for (i in 1:N) {
    row_vector[IA[i+1] - IA[i]] x;
    vector[IA[i+1] - IA[i]] w_;
    matrix[IA[i+1] - IA[i], K] V_;
    x = RA[IA[i]:(IA[i+1]-1)]';      // 非ゼロ成分のみの変数
    w_ = w[JA[IA[i]:(IA[i+1]-1)]];   // 非ゼロ成分のみの変数
    V_ = V[JA[IA[i]:(IA[i+1]-1)],:]; // 非ゼロ成分のみの変数
    
    y[i] ~ normal(w0 + x * w_ + 0.5 * sum((x * V_).^2 - (x.^2) * (V_.^2)), sig);
  }

  w0 ~ normal(mu_w, sig_w);
  w ~ normal(mu_w, sig_w);
  to_vector(V) ~ normal(mu_v, sig_v);
}

generated quantities {
  vector[N] est;
  vector[N_test] pred;
  for (i in 1:N){
    row_vector[IA[i+1] - IA[i]] x;
    vector[IA[i+1] - IA[i]] w_;
    matrix[IA[i+1] - IA[i], K] V_;
    x = RA[IA[i]:(IA[i+1]-1)]';
    w_ = w[JA[IA[i]:(IA[i+1]-1)]];
    V_ = V[JA[IA[i]:(IA[i+1]-1)],:];
    
    est[i] = normal_rng(w0 + x * w_ + 0.5 * sum((x * V_) .^2 - (x.^2) * (V_.^2)), sig);
  }
  for (i in 1:N_test) {
    row_vector[IA_test[i+1] - IA_test[i]] x;
    vector[IA_test[i+1] - IA_test[i]] w_;
    matrix[IA_test[i+1] - IA_test[i], K] V_;
    x = RA_test[IA_test[i]:(IA_test[i+1]-1)]';
    w_ = w[JA_test[IA_test[i]:(IA_test[i+1]-1)]];
    V_ = V[JA_test[IA_test[i]:(IA_test[i+1]-1)],:];
    
    pred[i] = normal_rng(w0 + x * w_ + 0.5 * sum((x * V_).^2 - (x.^2) * (V_.^2)), sig);
  }
}

 シンプルなFMとの違いは、CSR形式のデータを入力にしているところです。CSRは値と列のインデックス、列のインデックスが何行目かを示すポインタの3種類の配列でスパースデータを効率的に表現したデータ構造です。ここで示したStanコードではRAが値、JAが列のインデックス、IAがポインタです。CSRなどのスパースデータ構造の詳細については、検索するとWikiqiitaの記事などでわかりやすい解説が多数見つかるのでそちらをご参照ください。

 スパースデータに対応したFMとシンプルなFMとでは基本的なモデル式は変わりませんが、スパースデータに対応したFMでは非ゼロデータのみ計算させるための変数が導入されているところが異なっています。インデックスの部分は一見複雑そうに見えますが、CSRのデータ構造が頭に入っていれば理解するのは容易です。なお、インデックスは0ではなく1から始まることを仮定しています。推定対象のモデルパラメータはシンプルなFMでもスパースデータ対応のFMでも同じで、事前分布の設定も同一です。

 Stanの呼び出しRコードは以下のとおりです。CSR形式のデータに変換するのにSparseMライブラリのas.matrix.csr()関数を使っています。この関数で変換するとインデックスが0ではなく1から始まるので、Stanコードでも1から始まることを前提としたコードになっています。

# スパースデータ対応のFM ------------------------------------------------------------
X_csr <- as.matrix.csr(X)
X_test_csr <- as.matrix.csr(X_test)

model_fm_sparse <- cmdstan_model("fm_sparse.stan")
fit_fm_sparse <- model_fm_sparse$sample(
  data = list(
    N = nrow(X_csr),
    M = ncol(X_csr),
    K = K,
    N_ra = length(X_csr@ra),
    RA = X_csr@ra,
    IA = X_csr@ia,
    JA = X_csr@ja,
    y = y,
    N_test = nrow(X_test_csr),
    N_ra_test = length(X_test_csr@ra),
    RA_test = X_test_csr@ra,
    IA_test = X_test_csr@ia,
    JA_test = X_test_csr@ja
  ))
s <- fit_fm_sparse$summary()
rmse_fm <-
  s %>% 
  filter(variable %>% str_detect("est")) %>% 
  mutate(y = y_) %>% 
  mutate(diff = mean - y) %>% 
  summarise(rmse_fm = sqrt(mean(diff^2)))
rmse_test_fm <-
  s %>% 
  filter(variable %>% str_detect("pred")) %>% 
  mutate(y = y_test_) %>% 
  mutate(diff = mean - y) %>% 
  summarise(rmse_test_fm = sqrt(mean(diff^2)))
print(rmse_fm)
print(rmse_test_fm)
> print(rmse_fm)
# A tibble: 1 × 1
  rmse_fm
    <dbl>
1   0.151
> print(rmse_test_fm)
# A tibble: 1 × 1
  rmse_test_fm
         <dbl>
1         10.5

サンプリング過程が異なるのでシンプルなFMと全く同じ結果にはならないものの、問題なく推定や予測ができていることがわかります。

 レコメンデーションのように極端に非ゼロ要素が少ない場合はスパースデータ対応のFMのほうが優れています。しかし、今回の学習用サンプルデータのようにデータ量が少なかったり非ゼロ要素の割合がそれほど少なくない場合はスパースデータ対応のFMよりシンプルなFMのほうがメモリ使用量が少なく計算時間も短いことが多いです。目的に応じて適宜使い分けたほうがよいでしょう。

 また、実行してみるとわかりますが、StanでFMのパラメータ推定を行うと警告が発生することが多いです。この問題への対処法については後日別記事で扱おうと思います。

まとめ

 今回はスパースデータに対応したFMのStan実装を紹介しました。レコメンデーション用FMをStanで実装できると試行錯誤的なモデル改善を効率的に行うことができます。しかし、レコメンデーション用FMを使う場合スパースデータを効率的に処理することが必要になります。Stanでも入力データやコードを工夫することでスパースデータを効率的に処理することが可能です。そこで、今回はFMをスパースデータに対応させる方法を紹介しました。