scikit-learn準拠の学習器を作ってgrid searchとかcross validationする
Python Advent Calender 2014の19日目。
scikit-learnに準拠した学習器を自分で実装してscikit-learnに実装されているgrid searchとかcross validationを使えるようにするお話。Pythonの話というか完全にscikit-learnの話なんだけど、まあいいよね。
scikit-learnについてはこの辺がわかりやすいかな。
pythonの機械学習ライブラリscikit-learnの紹介
はじパタlt scikit-learnで始める機械学習
scikit-learn準拠にするには?
全部下のページに書いてある。
Contributing — scikit-learn 0.15.2 documentation
やること
- sklearn.base.BaseEstimatorを継承する
- 回帰ならRegressorMixinを(多重)継承する
- 分類ならClassifierMixinを(多重)継承する
- fitメソッドを実装する
- 学習データとラベルを受け取って学習したパラメータをフィールドにセットする
- initでパラメータをいじる操作を入れるとgrid searchが動かなくなる(後述)
- predictメソッドを実装する
- テストデータを受け取ってラベルのリストを返す
実装例(リッジ回帰)
scikit-learn-compatible Ridge Regression
Grid search
使用例
パラメータlamb
をうまく選んでやらないといけない。scikit準拠で実装したのでscikitに実装されてるgrid searchが使えて簡単にパラメータのバリデーションが出来る。
GridSearchCV
に学習器RidgeRegression()
とパラメータのリストparameters
とデータを何分割するかcv
を与える。それでfit
してbest_estimator_
を取るだけでスコアが最も高かったパラメータを持った学習器が手に入る。Grid searchなんてめんどくさい作業がたった4行で出来るなんて!
Grid search for oreore regression
結果
青が学習したい関数、青い点がgrid searchに使ったデータ、緑が学習した曲線。うまくパラメータ(ラムダ)が選ばれて過学習してない。やったね!
Cross validation
使用例
5-fold cross validationをする。簡単のためラムダは0にした。
Cross validationもものすごく簡単で、学習器、データ、cv
、スコアとして何を使うか(ここではMSE)を渡すだけ。結果としてcvの数だけ指定したスコアが返ってくる。たったの3行。
Cross validation for oreore regression
結果
$ python cross_validation.py -0.0552436216908
なぜかMSEなのに負の結果が帰って来て困惑したけどこれは仕様らしくて、下のページで議論されてる。どうやらMSEが-1倍された値が返ってくるらしい。
MSE is negative when returned by cross_val_score #2439
initでパラメータをいじるとダメ
例えば上で実装したRidgeRegressionの6行目を下のように変えるとgrid searchがエラーを吐いて動かなくなる。
self.lamb = 2*lamb
initではパラメータの代入しかしちゃいけないっぽい。公式にもこう書いてある。
As grid_search.GridSearchCV uses set_params to apply parameter setting to estimators, it is essential that calling set_params has the same effect as setting parameters using the init method. The easiest and recommended way to accomplish this is to not do any parameter validation in
__init__
. All logic behind estimator parameters, like translating string arguments into functions, should be done in fit.
まとめ
scikit-learn準拠の学習器を作ってgrid searchやcross validationをした。簡単なルールに従うだけでscikit-learn準拠に出来る。もちろんscikit-learnに元々入ってる学習器に対してもgrid searchやcross validationできる。超便利。