kakakakakku blog

Weekly Tech Blog: Keep on Learning!

train_test_split() の stratify パラメータを使って層化サンプリングをする

データセットを分割するときに scikit-learntrain_test_split() をよく使う.今回は train_test_split() に設定できる stratify パラメータを試す.stratify「層化」という意味で「データセットの特性を考慮した分割」とも言える.特に「不均衡データセット」を使うときに重要になる.

scikit-learn.org

train_test_split() をデフォルト設定で使う

train_test_split() のデフォルト設定を抜粋すると以下のようになる.stratify はデフォルトで None になる.

  • train_size = 0.75(トレーニングデータ 75 %)
  • test_size = 0.25(テストデータ 25 %)
  • shuffle = True(ランダムに分割する)
  • stratify = None(層化なし)

例として,scikit-learn に組み込まれた「ワインデータセット🍷」を使う.正解ラベルは 012「計3種類」あり,それぞれの分布は以下のようになっている.ワインデータセットは不均衡ではないけど,検証には使える.

  • 178 件
    • 正解ラベル 0 : 59 件(33.1 %)
    • 正解ラベル 1 : 71 件(39.8 %)
    • 正解ラベル 2 : 48 件(26.9 %)

scikit-learn.org

train_test_split(X, y) のように「デフォルト設定のまま」実行するように以下のコードを書いた.3回実行したところ「正解ラベル 0」だと 0.34 %0.32 %0.36 % と推移した.もし不均衡データセットを使うと,より顕著に差が出る可能性がある.

import numpy as np
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

wine = load_wine()
X = wine.data
y = wine.target

def calc(i, y):
        print(str(i+1) + '回目 (' + str(len(y_train)) + '件)')
        print('・Label 0 : ' + str(np.round((y == 0).sum()/len(y), 2)) + ' %')
        print('・Label 1 : ' + str(np.round((y == 1).sum()/len(y), 2)) + ' %')
        print('・Label 2 : ' + str(np.round((y == 2).sum()/len(y), 2)) + ' %')
        print('---')

for i in range(3):
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    calc(i, y_train)

# 1回目 (133件)
# ・Label 0 : 0.34 %
# ・Label 1 : 0.38 %
# ・Label 2 : 0.28 %
# ---
# 2回目 (133件)
# ・Label 0 : 0.32 %
# ・Label 1 : 0.41 %
# ・Label 2 : 0.27 %
# ---
# 3回目 (133件)
# ・Label 0 : 0.36 %
# ・Label 1 : 0.36 %
# ・Label 2 : 0.28 %
# ---

train_test_split()stratify パラメータを使う

次にtrain_test_split(X, y, stratify=y) のように stratify パラメータを使って実行する.正解ラベル y を前提に「層化」するため,分布を維持しながらデータセットを分割することができる.なるほど💡

for i in range(3):
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)
    calc(i, y_train)

# 1回目 (133件)
# ・Label 0 : 0.33 %
# ・Label 1 : 0.4 %
# ・Label 2 : 0.27 %
# ---
# 2回目 (133件)
# ・Label 0 : 0.33 %
# ・Label 1 : 0.4 %
# ・Label 2 : 0.27 %
# ---
# 3回目 (133件)
# ・Label 0 : 0.33 %
# ・Label 1 : 0.4 %
# ・Label 2 : 0.27 %
# ---

まとめ

今回は scikit-learntrain_test_split()「層化サンプリング」ができる stratify パラメータを試した.覚えておこう!