苦労する遊び人の玩具箱
latest
  • OS 関係
  • プログラミング言語
    • C: 糖衣構文を備えたメモリ
    • CSS Tips
    • Haskell:
    • Java Script: プロトタイプベース言語
    • julia : 科学技術計算用言語
    • Python: 最もチャーミングなプログラミング言語
      • ライブラリ関係
        • PyInstaller: freeze 化
        • Cython : C との融合による高速化
        • Django : データベース中心の Web アプリ開発ツール
        • ggplot: 描画作図
        • inquirer: 対話的コマンドライン
        • Ipython : python 用の対話式シェル
        • matplotlib: 可視化用ライブラリ
        • Pandas : R ライクなデータ操作
        • pip-tool : pip で管理しているライブラリの管理を楽にする
        • pyenv: python 用仮想環境
        • pylean2 : Deep Learning の python 実装
        • pySide : python 用 GUI 作成ライブラリ
        • pyslack: slack API ラッパー
        • Scikit-learn : 機械学習ライブラリ
        • Seaborn
        • segEval: セグメンテーションチェック
        • Sphinx:ドキュメントビルダ
        • StatsModels : 統計モデル
        • Trac: プロジェクト管理システム
        • watchdog : ファイル監視ライブラリ
        • word2vec
        • xlsx2csv : xlsx ファイルを CVS に変換
      • Tips
    • R : 統計処理
  • Tool 関係
  • 読書ログ
  • 研究関係
苦労する遊び人の玩具箱
  • Docs »
  • プログラミング言語 »
  • Python: 最もチャーミングなプログラミング言語 »
  • Scikit-learn : 機械学習ライブラリ »
  • Grid Search: パラメータチューニング
  • Edit on Bitbucket

Grid Search: パラメータチューニング¶

Last Change: 15-Jan-2016.
author : qh73xe

このページでは sklearn におけるパラメータチューニングの方法について記述します.

GridSearch¶

GridSearch とはようは機械学習のアルゴリズムで設定できるパラメータを 絨毯爆撃を行い最適なパラメータを得る手法です.

こういうと,とてもシンプルでナイーブな方法に思えますが, 結構入り組んだ for 文を書くことになるので,自分で実装するのは結構手間がかかります.

sklearn ではこの GridSearch を気軽に行うための関数が用意されているのでそれを使用 するのが便利です.

How to Use¶

sklearn で GridSearch を行うには GridSearchCV という関数を使用します. 以下に SVM を使用した場合でのグリッドサーチのサンプルを記述します.

from sklearn import datasets  # サンプル用のデータ・セット
from sklearn.grid_search import GridSearchCV
from sklearn.svm import SVC  # SVM の実行関数
from sklearn.cross_validation import train_test_split  # 訓練データとテストデータを分ける関数
from sklearn.metrics import classification_report, confusion_matrix  # 学習結果要約用関数

# サンプル用のデータを読み込み
digits = datasets.load_digits()
n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)


# 探索するパラメータを設定
param_grid = [
    {'C': [1, 10, 100, 1000], 'kernel': ['linear']},
    {'C': [1, 1 0, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
]

# 評価関数を指定
scores = ['accuracy', 'precision', 'recall']

# 各評価関数ごとにグリッドサーチを行う
score in scores:
    print score
    clf = GridSearchCV(SVC(C=1), param_grid, cv=5, scoring=score, n_jobs=-1)  # n_jobs: 並列計算を行う(-1 とすれば使用PCで可能な最適数の並列処理を行う)
    clf.fit(X_train, y_train)

    print clf.best_estimator_  # 最適なパラメータを表示

    for params, mean_score, all_scores in clf.grid_scores_:
        print "{:.3f} (+/- {:.3f}) for {}".format(mean_score, all_scores.std() / 2, params)

    # 最適なパラメータのモデルでクラスタリングを行う
    y_true, y_pred = y_test, clf.predict(X_test)
    print classification_report(y_true, y_pred)  # クラスタリング結果を表示
    print confusion_matrix(y_true, y_pred)       # クラスタリング結果を表示

上記のコードの概説をします. とりあえず最初の数行はライブラリのインポートを行っています. それぞれの関数がどのようなものなのかはコメントに記述しているので省略しますが, GridSearchCV と 機械学習のアルゴリズムが実装されている関数(今回の場合 SVC)が最低限必要です.

サンプル用のデータを読み込んでいる部分は基本的に無視してしまってよいです. ただし,訓練用のデータと,テスト用のデータを分けておくことはよくある話なので train_test_split 関数は知っておいて損はないかと思います.

グリッドサーチは前述の通り基本的に絨毯爆撃であるので, それぞれのパラメータを設定しておく必要があります. パラメータの設定には以下の構文を使用します.

  • 変数名はどうでもよいのですが
# 探索するパラメータを設定
param_grid = [
    {'C': [1, 10, 100, 1000], 'kernel': ['linear']},
    {'C': [1, 1 0, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
]

注釈

アルゴリズムのパラメータを調べるには

sklearn で使用できる関数のパラメータを取得するには以下の構文を使用します.

estimator.get_params()

実際にグリッドサーチを行うためには以下のようにします. 第一引数には使用したいアルゴリズムのインスタンスをおきます. cv はグリッドサーチに使用するクロスバリデーションの分割数です. scoring には結局良いモデルとは何なのかの評価指標をいれていくことになります.

clf = GridSearchCV(SVC(C=1), param_grid, cv=5, scoring=score, n_jobs=-1)
clf.fit(X_train, y_train)

ここで注意が必要なのは先に fit を行わないと最適なパラメータを取得できないことです.

注釈

scoring

scoring はデフォルトで以下の値が設定されます.

  • クラスタリング関係のアルゴリズムでは sklearn.metrics.accuracy_score
  • 回帰系のアルゴリズムでは sklearn.metrics.r2_score

その他使用可能な評価指標に関しては以下のページを確認してください

  • http://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter

最後に返り値に関して説明しておきます.

grid_scores_ : 名前付きタプルのリスト
  • param_grid に含まれるすべてのパラメータの得点

  • それぞれのエントリごとのパラメータ設定に対応しています

  • 名前付きタプルは以下の属性を持っています

    • parameters: パラメータ設定の辞書
    • mean_validation_score: クロスバリデーションの平均スコア
    • cv_validation_scores: クロスバリデーションそれぞれの試行ごとのスコア
best_estimator_ : 分類器
  • グリッドサーチの結果最も得点の高い(あるいはエラーレート等を評価指標に選んだ場合には低い)パラメータの分類器です
  • refit=False になっている場合には使用できないので注意です
best_score_ : float
best_estimator の得点
data.best_params_ : dict
最適バラメータの辞書
data.scorer_ : function
最適バラメータを探した際に使用した評価関数
Next Previous

© Copyright 2014, qh73xe. Revision 6bce65a0.

Built with Sphinx using a theme provided by Read the Docs.