乱れた森林再生委員会(学習集会)

What I cannot create, I do not understand.(作ることができないなら,理解できていないということだ) - R. Feynman(物理学者)

木を育てる その2:分割を繰り返して適当なところで止める(ジコリューで行こう!#2)

こんにちは,植木等*1です。

わかっちゃいるけどやめられない。

f:id:tatamiyatamiyatatatamiya:20190224150754p:plain

この記事では前回に引き続き,決定木を自己流で実装していきたいと思います。

tatamiya.hatenablog.com

前回は不純度が最低になるような分割の探索を行いましたので, 今回はこれを繰り返し,適当なところで止めてクラス分類を行いたいと思います。

指針

前回までのあらすじと今回やること

与えられた特徴量ベクトルXとtargetベクトルyを元に,分割する特徴量・各特徴量の値で全探索し,不純度が最低になる分割方法見つけ出すところまでを,以下の関数にまとめました。

途中の関数find_optimal_divisionは,各特徴量について最適な分割点を探索するものです。

def divide_tree(X, y):

    results = np.apply_along_axis(find_optimal_division, 0, X, y)

    arg_div = np.argmin(results[1])
    x_div = results[0, arg_div]

    return arg_div, x_div

返り値のうち,arg_divは分割する特徴量のindexを,x_divは分割の際の閾値です。

分割の指標となる不純度の計算には,Gini係数を用いました。

今度はこの最適な分割後の断片2つそれぞれに再度分割を施していき,より不純度を下げていきます。 これには,先ほどの関数divide_tree再帰的に呼び出していきます。

サンプルのデータとしては,引き続きirisを用います。

どこで分割を止めるか?

途中で不純度が十分下がりきった場合はそこで分割を止めます。

実際,一回分割した後の断片を見てみると,片方には0番目のクラス(品種名:Setosa)しか含まれていないことがわかります。

y[X[:, arg_div] <= x_div]
# array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

これ以上分割しても意味がないので,「不純度が一定値以下になったら分割をやめる」という条件を設けます。

また,データを細かく分けすぎてしまってもいけません。学習用にInputしたデータに過度に依存した分割ができてしまい,新しいデータを入れて予測を行う際に不適切な分割を行うことで精度が下がる恐れがあります。(過学習

そのような事態を避けるためにも,

  • 分割後の断片の大きさ
  • 分割を行う回数(深さ)

閾値を設け,適当なところで止めることにします。

実装

再帰呼び出しにより分割後の断片を再分割していく

まず,全体のイメージ的にはこんな実装になります。

def go_on_dividing_temp1(X, y):

    arg_div, x_div = divide_tree(X, y)
    
    mask = X[:, arg_div] > x_div
    X_right, X_left = X[mask], X[~mask]
    y_right, y_left = y[mask], y[~mask]
    
    go_on_dividing_tmp1(X_left, y_left)
    go_on_dividing_tmp1(X_right, y_right)

このgo_on_dividing_temp1関数では,最初に一回分割を行なってから,得られた特徴量が閾値より大きいか小さいかで右断片と左断片に分けます。

この際に,閾値以上のデータはTrue, それ以下のものはFalseとするbinaryのmaskを作成し,これをX, yの引数に撮りました。leftの方で~maskとしていますが,~はTrue, Falseを逆転させます。

そして,分けた断片を再度別々にgo_on_dividing_temp1関数に打ち込みます。

定義した関数の中でその関数自身を呼び出すこのやり方は,再帰呼び出しと呼ばれます。 最初は何をやっているか理解が難しいかもしれませんが,何度か動かしていくうちに徐々に慣れてくると思います。

分割の停止条件を設ける

これで枝葉をどんどん伸ばしていくことができますが,このままでは分割できるデータがなくなるまで延々と分割を続けていってしまいます。

そのため,「指針」で触れた下記の3つの停止条件を付加します:

  • 不純度が一定値以下
    • 今回は各断片ごとのGini係数に閾値を設けます。
  • 分割後の断片の大きさ
    • 分割後の断片の長さlen(y_left), len(y_right)閾値を設けます。
  • 分割の回数
    • 分割の深さをdepthとして変数に入れておき,こちらに閾値を設けます。

以上を実装すると,以下のようになります。

def go_on_dividing_temp2(X, y, depth=1,
                   threshold_gini=0.05, min_node_size=5, max_depth=3):        
    
    arg_div, x_div = divide_tree(X, y)
        
    mask = X[:, arg_div] > x_div
    X_right, X_left = X[mask], X[~mask]
    y_right, y_left = y[mask], y[~mask]
    
    gini_left = gini(y_left)
    gini_right = gini(y_right)
    
    list_divided = [(X_left, y_left, gini_left), (X_right, y_right, gini_right)]
    
    for divided in list_divided:
        
        X_i, y_i, gini_i = divided

        if gini_i > threshold_gini and len(y_i)>min_node_size and depth < max_depth:
            go_on_dividing_temp2(X_i, y_i, depth=depth+1)
        else:
            class_decided = np.bincount(y_i).argmax()

Gini係数,分割後断片のデータ長さ,分割の深さの閾値をそれぞれthreshold_gini, min_node_size, max_depthとしました。 デフォルト引数でとりあえず適当な値を入れてあります。

分割の深さdepthgo_on_dividing_temp2関数の引数として持っておき,デフォルトの値を1,呼び出す際に一つインクリメントするようにしました。

これらを合わせて,以下の条件が全部満たされている間だけ分割を続けるようにしてあります。

  • gini_i > threshold_gini
  • len(y_i)>min_node_size
  • depth < max_depth

y_i, gini_iには,分割後断片左右どちらかの正解クラスベクトルとGini係数が入ります。

もし3つの条件のうちどれか一つでも満たされない場合,分割をやめます。

そしてそれに合わせて,分割されたデータ断片がどのクラスに分類されるかを決めます。

ここでは,断片に含まれるクラスのうち最も多数派を占めるものを選び出します。

class_decided = np.bincount(y_i).argmax()

分割の様子を出力する

これで一通りの計算はできたのですが,このまま実行しても何も出力されません。

  • どのような分割を経て,
  • 最終的にどのクラスに行き着いたか

を出力するにはどうすればいいでしょうか?

ここの部分は少し頭を使う必要があったのですが,最終的に次のようにしました:

def go_on_dividing(X, y, depth=1, div_set=None,
                   threshold_gini=0.05, min_node_size=5, max_depth=3):
    
    global i_node
    if div_set is None:
        div_set = []
    
    i_node_current = i_node
    
    arg_div, x_div = divide_tree(X, y)
    
    print("=== node {} (depth {}): arg_div -> {}, x_div -> {} ===".format(i_node, depth, arg_div, x_div))
    
    mask = X[:, arg_div] > x_div
    X_right, X_left = X[mask], X[~mask]
    y_right, y_left = y[mask], y[~mask]
    
    gini_left = gini(y_left)
    gini_right = gini(y_right)
    
    list_divided = [(X_left, y_left, gini_left), (X_right, y_right, gini_right)]
    
    for lr, divided in enumerate(list_divided):

        div_set_tmp = div_set.copy()
        div_set_tmp.append((depth, i_node_current, arg_div, x_div, lr))
        i_node +=1
        
        X_i, y_i, gini_i = divided
        if gini_i > threshold_gini and len(y_i)>min_node_size and depth < max_depth:
            go_on_dividing(X_i, y_i, depth=depth+1, div_set=div_set_tmp)
        else:
            class_decided = np.bincount(y_i).argmax()
            print(div_set_tmp, class_decided)

分割ごとに固有の通し番号i_nodeを振り,グローバル変数として持っておくことにしました。 Node(ノード)というのは,「節」や「結節点」という意味で, ここでは分割・もしくはクラス判定を行う部分を指します。

さらに,分割の経過をリストdiv_setに保存していきます。 この中には,

  • 何回目の分割(depth)の,
  • 何番目のノード(i_node)で,
  • どの特徴量index(arg_div)を
  • 閾値いくつ(x_div)で分割して,
  • 左側・右側どちらの断片に入ったか?(lr)

の情報を入れていき,go_on_dividing関数の引数に入れて下流のノードに伝達していきます。

最終的に分割を止める際には,このdiv_setと,分類後のクラス番号を表示します。

これらを実行すると以下のようになります。

i_node = 0
go_on_dividing(X, y)
'''
=== node 0 (depth 1): arg_div -> 2, x_div -> 1.9 ===
[(1, 0, 2, 1.9, 0)] 0
=== node 2 (depth 2): arg_div -> 3, x_div -> 1.7 ===
=== node 3 (depth 3): arg_div -> 2, x_div -> 4.9 ===
[(1, 0, 2, 1.9, 1), (2, 2, 3, 1.7, 0), (3, 3, 2, 4.9, 0)] 1
[(1, 0, 2, 1.9, 1), (2, 2, 3, 1.7, 0), (3, 3, 2, 4.9, 1)] 2
[(1, 0, 2, 1.9, 1), (2, 2, 3, 1.7, 1)] 2
'''

ノード番号i_nodeはグローバル引数で定義したので, 呼び出しの前に初期化しておく必要があります。

なお,リストdiv_setを更新する際に,

div_set_tmp = div_set.copy()
div_set_tmp.append((depth, i_node_current, arg_div, x_div, lr))

のように,いったんコピーdiv_set_tmpを作成してから更新・go_on_dividing関数に渡しています。

これをせずにdiv_setにそのままappendして引数に渡してしまうと,全ノードでこのリストが共有されてしまいます。 そのため,最終的な分類結果を表示させるときに,自分がたどっていないノードでの分類過程まで出力されてしまいます。


ノード番号を振る際にグローバル変数を使っている点が気になりますが, クラスにしてまとめる際にうまく処理できればと思います。

なお,今回の実装ではノード番号は計算を行なった順に振っていますが, 同じ深さのノード間では左から順に1ずつ増えていくように振る方がメジャーなようです (参考:みんな大好きはじめてのパターン認識 第11章)*2。 これに従ってノード番号を振るにはどうすれば良いかは,現在考え中です。

ともあれ以上で一通り決定木の学習のプロセスが実装できました。 次回はこれをもとに新しいデータをいれてクラスラベル(品種)を予測できるようにしたいと思います。

*1:1960年代にヒットしたコメディーバンド ハナ肇クレイジーキャッツのボーカル。代表曲はスーダラ節,ハイそれまでョなど。2007年に死去。他のメンバーも気になって調べてみたところ,犬塚弘のみご存命(90歳)であった。

*2:個人的には数式の導入がそっけないのでそこまで好きではない。ただ,第11章の決定木〜アンサンブル学習はコンパクトにまとまっているので参考にしている。