パーセプトロンの学習規則

線形分離可能な2クラスの識別問題をパーセプトロンの学習規則を使って解く。

(NN法について)
1. NN法(Nearest neighborhood:最近傍法)とは、
各クラスw1,w2,..に属するプロトタイプ(代表点)p1,p2,..に対して
入力された特徴ベクトルxとの距離が最小となるプロトタイプを用いる手法
(プロトタイプpiはクラスのラベルとも見れる))

2. NN法の識別関数は
 g(\mathbf{x}) =\mathbf{w}^t \mathbf{x}とかける。
なぜなら、プロトタイプ p_iと特徴ベクトル xとの距離は
 D(x,p_i)=||x-p_i||であって、この2乗は
 ||x-p_i||^2 = ||\mathbf{x}||^2 -2(\mathbf{p_i}^t \mathbf{x} - \frac{1}{2} ||p_i||^2 )
距離Dを最小化することは、すなわちカッコの中身を最大化することなので
 g_i(\mathbf{x}) = \mathbf{p_i}^t \mathbf{x} - \frac{1}{2} ||p_i||^2
第一項はベクトルxに対する定数係数、第二項はxに依存しない定数項なので
 g_i(\mathbf{x})=\sum_{j=0}^d w_{ij} x_j (ただし x0=1)と書いてよい。
よって、 g(\mathbf{x}) =\mathbf{w}^t \mathbf{x}

3. 識別関数の更新
2クラス識別問題について。
クラスw1に属するx1に対し、g(x1)>0
クラスw2に属するx2に対し、g(x2)<0
となるように識別関数を設定したい。
誤識別が起こった場合の更新規則は、
g(x1)< 0 => w' = w + ρx
g(x2) >0 => w' = w - ρx
である。
これを、誤識別が起こらなくなるまで繰り返す。

以上をパーセプトロンの学習規則という。
特徴空間上の学習データが超平面で分割可能(線形分離可能)であれば、
パーセプトロンの学習規則は有限回の繰り返しで終了する。
これをパーセプトロンの収束定理という。

それでは実際の問題。
ターゲットとして2次元正規分布を生成する。

# coding: utf-8

import numpy as np
import matplotlib.pyplot as plt
from pylab import *

# サンプルデータの生成(2次元正規分布)
mean1 = [0,0]; cov1 = [[4,0],[0,100]]; N1 = 1000
X1 = np.random.multivariate_normal(mean1,cov1,N1)
mean2 = [30,-30]; cov2 = [[1,20],[20,50]]; N2 = 1000
X2 = np.random.multivariate_normal(mean2,cov2,N2)
X = np.concatenate((X1,X2))
# 描画
x,y = X.T
plt.plot(x,y,'k.'); plt.axis('equal'); #plt.show()

#識別関数 g(x,y)=w0 + w1*x + w2*y
#初期値
w0 = 60
w1 = 4
w2 = -1.0
rho = 0.001

#パーセプトロンの学習規則を適用
i=0
while 1:
    #クラス1に対する識別
    for x1 in X1:        
        g = w0 + w1 * x1[0] + w2 * x1[1]
        if g < 0:
            w0 = w0 + rho*1
            w1 = w1 + rho*x1[0]
            w2 = w2 + rho*x1[1]
            i += 1
    #クラス2に対する識別
    for x2 in X2:
        g = w0 + w1 * x2[0] + w2 * x2[1]
        if g > 0:
            w0 = w0 - rho*1
            w1 = w1 - rho*x2[0]
            w2 = w2 - rho*x2[1]
            i += 1
    if i==0: 
        break
    i = 0 

#得られた線形識別関数を描画
xsamp = linspace(-20,60,1000)
ysamp = -1.0*w1/w2 * xsamp -1.0 * w0/w2
plt.plot(xsamp,ysamp)
plt.show()

実行結果