データセットを分割するときに scikit-learn の train_test_split()
をよく使う.今回は train_test_split()
に設定できる stratify
パラメータを試す.stratify
は「層化」という意味で「データセットの特性を考慮した分割」とも言える.特に「不均衡データセット」を使うときに重要になる.
train_test_split()
をデフォルト設定で使う
train_test_split()
のデフォルト設定を抜粋すると以下のようになる.stratify
はデフォルトで None
になる.
train_size = 0.75
(トレーニングデータ 75 %)test_size = 0.25
(テストデータ 25 %)shuffle = True
(ランダムに分割する)stratify = None
(層化なし)
例として,scikit-learn に組み込まれた「ワインデータセット🍷」を使う.正解ラベルは 0
と 1
と 2
の「計3種類」あり,それぞれの分布は以下のようになっている.ワインデータセットは不均衡ではないけど,検証には使える.
- 178 件
- 正解ラベル 0 : 59 件(33.1 %)
- 正解ラベル 1 : 71 件(39.8 %)
- 正解ラベル 2 : 48 件(26.9 %)
train_test_split(X, y)
のように「デフォルト設定のまま」実行するように以下のコードを書いた.3回実行したところ「正解ラベル 0」だと 0.34 %
→ 0.32 %
→ 0.36 %
と推移した.もし不均衡データセットを使うと,より顕著に差が出る可能性がある.
import numpy as np from sklearn.datasets import load_wine from sklearn.model_selection import train_test_split wine = load_wine() X = wine.data y = wine.target def calc(i, y): print(str(i+1) + '回目 (' + str(len(y_train)) + '件)') print('・Label 0 : ' + str(np.round((y == 0).sum()/len(y), 2)) + ' %') print('・Label 1 : ' + str(np.round((y == 1).sum()/len(y), 2)) + ' %') print('・Label 2 : ' + str(np.round((y == 2).sum()/len(y), 2)) + ' %') print('---') for i in range(3): X_train, X_test, y_train, y_test = train_test_split(X, y) calc(i, y_train) # 1回目 (133件) # ・Label 0 : 0.34 % # ・Label 1 : 0.38 % # ・Label 2 : 0.28 % # --- # 2回目 (133件) # ・Label 0 : 0.32 % # ・Label 1 : 0.41 % # ・Label 2 : 0.27 % # --- # 3回目 (133件) # ・Label 0 : 0.36 % # ・Label 1 : 0.36 % # ・Label 2 : 0.28 % # ---
train_test_split()
で stratify
パラメータを使う
次にtrain_test_split(X, y, stratify=y)
のように stratify
パラメータを使って実行する.正解ラベル y
を前提に「層化」するため,分布を維持しながらデータセットを分割することができる.なるほど💡
for i in range(3): X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y) calc(i, y_train) # 1回目 (133件) # ・Label 0 : 0.33 % # ・Label 1 : 0.4 % # ・Label 2 : 0.27 % # --- # 2回目 (133件) # ・Label 0 : 0.33 % # ・Label 1 : 0.4 % # ・Label 2 : 0.27 % # --- # 3回目 (133件) # ・Label 0 : 0.33 % # ・Label 1 : 0.4 % # ・Label 2 : 0.27 % # ---
まとめ
今回は scikit-learn の train_test_split()
で「層化サンプリング」ができる stratify
パラメータを試した.覚えておこう!