でかいチーズをベーグルする

でかいチーズはベーグルすべきです。

scikit-learn準拠の学習器を作ってgrid searchとかcross validationする

Python Advent Calender 2014の19日目。

scikit-learnに準拠した学習器を自分で実装してscikit-learnに実装されているgrid searchとかcross validationを使えるようにするお話。Pythonの話というか完全にscikit-learnの話なんだけど、まあいいよね。

scikit-learnについてはこの辺がわかりやすいかな。

pythonの機械学習ライブラリscikit-learnの紹介

はじパタlt scikit-learnで始める機械学習

scikit-learn準拠にするには?

全部下のページに書いてある。

Contributing — scikit-learn 0.15.2 documentation

やること

  1. sklearn.base.BaseEstimatorを継承する
  2. 回帰ならRegressorMixinを(多重)継承する
  3. 分類ならClassifierMixinを(多重)継承する
  4. fitメソッドを実装する
    • 学習データとラベルを受け取って学習したパラメータをフィールドにセットする
    • initでパラメータをいじる操作を入れるとgrid searchが動かなくなる(後述)
  5. predictメソッドを実装する
    • テストデータを受け取ってラベルのリストを返す

実装例(リッジ回帰)

scikit-learn-compatible Ridge Regression

Grid search

使用例

パラメータlambをうまく選んでやらないといけない。scikit準拠で実装したのでscikitに実装されてるgrid searchが使えて簡単にパラメータのバリデーションが出来る。

GridSearchCVに学習器RidgeRegression()とパラメータのリストparametersとデータを何分割するかcvを与える。それでfitしてbest_estimator_を取るだけでスコアが最も高かったパラメータを持った学習器が手に入る。Grid searchなんてめんどくさい作業がたった4行で出来るなんて!

Grid search for oreore regression

結果

f:id:yamaguchiyuto:20141205140900p:plain

青が学習したい関数、青い点がgrid searchに使ったデータ、緑が学習した曲線。うまくパラメータ(ラムダ)が選ばれて過学習してない。やったね!

Cross validation

使用例

5-fold cross validationをする。簡単のためラムダは0にした。

Cross validationもものすごく簡単で、学習器、データ、cv、スコアとして何を使うか(ここではMSE)を渡すだけ。結果としてcvの数だけ指定したスコアが返ってくる。たったの3行。

Cross validation for oreore regression

結果

$ python cross_validation.py
-0.0552436216908

なぜかMSEなのに負の結果が帰って来て困惑したけどこれは仕様らしくて、下のページで議論されてる。どうやらMSEが-1倍された値が返ってくるらしい。

MSE is negative when returned by cross_val_score #2439

initでパラメータをいじるとダメ

例えば上で実装したRidgeRegressionの6行目を下のように変えるとgrid searchがエラーを吐いて動かなくなる。

self.lamb = 2*lamb

initではパラメータの代入しかしちゃいけないっぽい。公式にもこう書いてある。

As grid_search.GridSearchCV uses set_params to apply parameter setting to estimators, it is essential that calling set_params has the same effect as setting parameters using the init method. The easiest and recommended way to accomplish this is to not do any parameter validation in __init__. All logic behind estimator parameters, like translating string arguments into functions, should be done in fit.

まとめ

scikit-learn準拠の学習器を作ってgrid searchやcross validationをした。簡単なルールに従うだけでscikit-learn準拠に出来る。もちろんscikit-learnに元々入ってる学習器に対してもgrid searchやcross validationできる。超便利。