概要
GBDT(Gradient Boosting Decision Tree; 勾配ブースティング決定木) は,決定木によるアンサンブル学習の一種.Kaggle で頻用されている.
GBDT は決定木の列の推定を統合(boosting)することで高精度な予測を行う.決定木の列は,既存の列のアンサンブルから生じる誤差を予測する決定木を新たに生成するということを繰り返して,逐次的に生成される.
資料
-
Gradient Boosting Interactive Playground
- トイデータで感じをつかむ
GBDT をちゃんと理解しようと思ったきっかけ
現実は、「特徴量エンジニアリングを含めた前処理がもっとも重要で、モデルはデフォルトのGBDT使っとけばOK」っていうものだから、Feature Store的な方面は良いけど、頑張ってベイズ最適化でモデルをチューニングするAutoML的方面ははっきり言って殆ど意味がない。ついでに論文も再現性がない。
— まますさん (@mamas16k) March 27, 2022
やっぱりGBDTで十分ですよね。前処理で勝負が着いているというのは、Kaggleでの経験はないですが仕事で数回テーブルデータを扱った感触ですごくよくわかります。データ元のビジネスサイドの人が何が効きそうか一番良く知ってる。それをきれいに再現するだけでうまく行ったりする https://t.co/8MFkIXm1c4
— zakkini (@yoshimasaizaki) March 27, 2022
内容
$m$ 次元特徴量データ $\mathbf{x}_i\in\mathbb{R}^m$ とそのラベル $y_i\in\mathbb{R}$ の組の集合 $\mathcal{D}={(\mathbf{x}_i, y_i)}$ を訓練データセットとする.サンプル数を $n=|\mathcal{D}|$ とする.
GBDT を $\phi:\mathbb{R}^m\rightarrow \mathbb{R}$ とおき,$K$ 個の決定木からなるとする. $$ \hat{y}_{i}=\phi\left(\mathbf{x}_{i}\right)=\sum_{k=1}^{K} f_{k}\left(\mathbf{x}_{i}\right), \quad f_{k} \in \mathcal{F}. $$ $\mathcal{F}$ はとりうる弱学習器全体の集合であり(ここでは決定木),$\mathcal{F}=\left\{f(\mathbf{x})=w_{q(\mathbf{x})}\right\}$ と書ける.$q$ は決定木の葉の index を返す関数.
まとめると,入力 $\mathbf{x}$ から各決定木の index $q(\mathbf{x})$ が決まり,決定木ごとの重み $w$ が定まる.この総和が出力 $\hat{y}$ であり,学習時にはラベル $y$ との誤差が測られる.
損失関数と勾配
$\phi$ の損失関数 $\mathcal{L}(\phi)$ を以下で定義する. $$ \mathcal{L}(\phi)=\sum_{i} l\left(\hat{y}_{i}, y_{i}\right)+\sum_{k} \Omega\left(f_{k}\right) $$ $l$ は微分可能な凸関数.$\Omega$ は木構造が複雑になることにペナルティを与える正則化項で,次のように定義する. $$ \begin{aligned} \Omega(f) &=\gamma T+\frac{1}{2} \lambda|w|^{2} \\ &=\gamma T+\frac{1}{2} \lambda\sum_{j=1}^Tw_j^2 \end{aligned} $$ ここで $T$ は木 $f$ の葉の数.
勾配木ブースティング
概要で述べたように,GBDT $\phi$ の決定木は逐次的に生成される.
いま決定木 $f_1,\ldots,f_{t-1}$ が生成されており,新たに $f_t$ を生成するとする.生成ずみの決定木による出力を $\hat{y}^{(t-1)}$ とおくと,このときの損失関数 $\mathcal{L}^{(t)}$ は $$ \mathcal{L}^{(t)}=\sum_{i=1}^{n} l\left(y_{i}, \hat{y}_{i}^{(t-1)}+f_{t}\left(\mathbf{x}_{i}\right)\right)+\Omega\left(f_{t}\right). $$ ただし過去に生成した決定木のパラメータは変えないものとする.すなわち最適化の対象は $f_t$ のみである.
ここで $f_{t}(\mathbf{x}_{i})$ を差分と見なし,関数 $l(y_i,\:\cdot\:)$ を $\hat{y}_{i}^{(t-1)}$ について $2$ 次までのテイラー展開で近似することを考える.
$g_{i}:=\partial_{\hat{y}^{(t-1)}} l\left(y_{i}, \hat{y}^{(t-1)}\right),h_{i}:=\partial_{\hat{y}^{(t-1)}}^{2} l\left(y_{i}, \hat{y}^{(t-1)}\right)$ とおくと, $$ \mathcal{L}^{(t)} \simeq \sum_{i=1}^{n}\left[l\left(y_{i}, \hat{y}^{(t-1)}\right)+g_{i} f_{t}\left(\mathbf{x}_{i}\right)+\frac{1}{2} h_{i} f_{t}^{2}\left(\mathbf{x}_{i}\right)\right]+\Omega\left(f_{t}\right). $$ $l(y_{i}, \hat{y}^{(t-1)})$ は関係ないので除去して $\tilde{\mathcal{L}}^{(t)}$ とおく. $$ \begin{aligned} \tilde{\mathcal{L}}^{(t)} &=\sum_{i=1}^{n}\left[g_{i} f_{t}\left(\mathbf{x}_{i}\right)+\frac{1}{2} h_{i} f_{t}^{2}\left(\mathbf{x}_{i}\right)\right]+\Omega\left(f_{t}\right) \\ &=\sum_{i=1}^{n}\left[g_{i} f_{t}\left(\mathbf{x}_{i}\right)+\frac{1}{2} h_{i} f_{t}^{2}\left(\mathbf{x}_{i}\right)\right]+\gamma T+\frac{1}{2} \lambda\sum_{j=1}^Tw_j^2 \end{aligned} $$ いま $I_{j}:=\left\{i \mid q\left(\mathbf{x}_{i}\right)=j\right\}$ とおく.つまり $I_j$ は index $j$ の葉に辿りつくような入力 $\mathbf{x}_{i}$ の添え字 $i$ 全体である.これによって総和を括ることができ, $$ \tilde{\mathcal{L}}^{(t)}=\sum_{j=1}^{T}\left[\left(\sum_{i \in I_{j}} g_{i}\right) w_{j}+\frac{1}{2}\left(\sum_{i \in I_{j}} h_{i}+\lambda\right) w_{j}^{2}\right]+\gamma T $$ となる.
あとは $w_j$ で微分して極値を求めればよい.最適値 $w_j^{*}$ と損失 $\tilde{\mathcal{L}}^{(t)}$ は以下のとおり. $$ w_{j}^{*}=-\frac{\sum_{i \in I_{j}} g_{i}}{\sum_{i \in I_{j}} h_{i}+\lambda},\quad \tilde{\mathcal{L}}^{(t)}(q)=-\frac{1}{2} \sum_{j=1}^{T} \frac{\left(\sum_{i \in I_{j}} g_{i}\right)^{2}}{\sum_{i \in I_{j}} h_{i}+\lambda}+\gamma T. $$
e.g. $\mathcal{L}$ が平均二乗誤差 (MSE) のとき
損失関数 $\mathcal{L}:=\displaystyle\frac{1}{n}\sum_{i=1}^n(y_i-\hat{y}_i)^2$ のときを考える.$g_i=2(y_i-\hat{y}_i^{(t-1)}),h_i=-2$ であるから, $$ w_j^{*}=\frac{\sum_{i\in I_j}(y_i-\hat{y}_i^{(t-1)})}{|I_j|+\frac{1}{2}\lambda}. $$
木構造の更新
木の深さや葉の数を制限していたとしても,木構造の数は膨大ですべて調べることは現実的でない.代えて 1 枚の葉から始めて繰り返し枝を追加していく貪欲アルゴリズムが採用される.
ある葉ノードを分割して 2 つの葉 $L,R$ を接続することを考える.分割前後の損失関数の差分を $\mathcal{L}_\text{split}$ とおく.また $\mathcal{L}_{L},\mathcal{L}_{R}$ は$L,R$ それぞれに関する損失の増分である. $$ \begin{aligned} \mathcal{L}_{\text {split }} &=\mathcal{L}-\left(\mathcal{L}_{L}+\mathcal{L}_{R}\right) \\ &=\frac{1}{2}\left[\frac{\left(\sum_{i \in I_{L}} g_{i}\right)^{2}}{\sum_{i \in I_{L}} h_{i}+\lambda}+\frac{\left(\sum_{i \in I_{R}} g_{i}\right)^{2}}{\sum_{i \in I_{R}} h_{i}+\lambda}-\frac{\left(\sum_{i \in I} g_{i}\right)^{2}}{\sum_{i \in I} h_{i}+\lambda}\right]-\gamma \end{aligned} $$
$L,R$ を貪欲に探索し,損失がもっとも小さくなるような split を続けて木構造を更新していく.