おいも貴婦人ブログ

生物系博士課程満期退学をしたAIエンジニアのブログ。

サモンのマップ化(Sammon mapping)

サモンのマップ化とは、高次元特徴ベクトルを持つデータを2次元にマッピングして、クラスタリングする方法です。誤差関数を以下のように定義する。

\(E=\frac{1}{\sum_i \sum_{j>i}d_{ij}} \sum_i \sum_{j>i}d_{ij} \frac{(d_{ij}-||{\bf r}_i -{\bf r}_j||)^2}{d_{ij}} \)
特徴ベクトルを\({\bf x}_i\)とすると、それを二次元にマップしたときの座標を\({\bf r}_i\)とする。\(d_{ij}=||{\bf x}_i - {\bf x}_j||\)である。この誤差関数を最小化するために、最急降下法を用いる。
\( {\bf r}_i(m+1)={\bf r}_i(m)-\alpha \frac{\frac{\partial E(m)}{\partial {\bf r}_i(m)}}{|\frac{\partial^2E(m)}{\partial {\bf r}_i^2(m)}|} \)
\(\alpha\)はパラメータで、0.3〜0.4くらいに設定する。このパラメータは、完全に経験的に決められるが、良い収束性を与えるのでマジック係数と呼ばれているらしい。おおざっぱなクラスタリングの可視化にサモンのマップ化は使われる。統計学機械学習の分野で有名なiris.dataを用い、Sammon mappingの正当性を評価しました。このデータには、3つのクラスが存在しますが、Sammon mappingでは2つのクラスにしか分類できていません。
iris4.dataは75行からなり、1行ずつに、一つのアヤメの分類と特徴が書いてあります。アヤメは3種類に分類され(25個ずつ)、4つの特徴があります。行の最初に、分類される番号、そして次に4つの特徴が書いてあります。
以下が結果。
f:id:oimokihujin:20140601085012g:plain
汚いコード。
networkxはネットワークグラフを書くためのパッケージです。詳しい使い方は下記のURLをご参照ください。
https://networkx.github.io
dist.cdistを使って、iris4.dataから読み取ったデータの距離行列(dmat)を作成する。2次元マップ(pos)を作成するためにrandomを使って、適当に配置する。dist.pdistで二次元マップ上における要素間のユークリッド距離を計算する。
参考文献
自己組織化マップ 改訂版

自己組織化マップ 改訂版

#! /usr/bin/env python
import scipy.spatial.distance as dist
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import random

random.seed(10)
data=[[float(j) for j in i.strip().split()[1:]] for i in open("iris4.data").readlines()]
dmat=dist.cdist(data,data)

pos={}
for i in range(len(data)):
    pos[i]=np.array([random.random()*10,random.random()*10])
def steep_des(number):
    fig=plt.figure()
    ax=fig.add_subplot(111)
    G=nx.Graph()
    c=0
    for i in range(len(data)):
        for j in range(i+1,len(data)):
            c+=dmat[i,j]
    alpha=0.3

    for i in range(len(pos)):
        dE=0
        ddE=[0,0]
        flag=True
        for j in range(len(pos)):
            dist_pos=dist.pdist([pos[i],pos[j]])
            if i==j or dist_pos==0:
                flag=False
                continue
            tmp1=(dmat[i,j]-dist_pos)/(dmat[i,j]*dist_pos)
            tmp2=(pos[i]-pos[j])
            dE+=-2/c*tmp1*tmp2
            for k in range(2):
                ddE[k]+=-(2/c)*tmp1*(1-tmp2[k]**2/dist_pos*(1/(dmat[i,j]-dist_pos)+1/dist_pos))
            #print tmp1,dist_pos,dmat[j,i]
            

        for j in range(2):
            pos[i][j]=pos[i][j]-alpha*dE[j]/abs(ddE[j])

    nx.draw_networkx_nodes(G,pos,nodelist=range(25),with_labels=False,node_color='r')
    nx.draw_networkx_nodes(G,pos,nodelist=range(26,50),with_labels=False,node_color='b')
    nx.draw_networkx_nodes(G,pos,nodelist=range(51,75),with_labels=False,node_color='g')

    plt.axis('off')
    num="%03d" % number
    plt.savefig("sammon"+num+".png")


for i in range(30):
    steep_des(i)
print pos

mapdata=[]

iris4.data

0 5.1 3.5 1.4 0.2
0 4.9 3.0 1.4 0.2
0 4.7 3.2 1.3 0.2
0 4.6 3.1 1.5 0.2
0 5.0 3.6 1.4 0.2
0 5.4 3.9 1.7 0.4
0 4.6 3.4 1.4 0.3
0 5.0 3.4 1.5 0.2
0 4.4 2.9 1.4 0.2
0 4.9 3.1 1.5 0.1
0 5.4 3.7 1.5 0.2
0 4.8 3.4 1.6 0.2
0 4.8 3.0 1.4 0.1
0 4.3 3.0 1.1 0.1
0 5.8 4.0 1.2 0.2
0 5.7 4.4 1.5 0.4
0 5.4 3.9 1.3 0.4
0 5.1 3.5 1.4 0.3
0 5.7 3.8 1.7 0.3
0 5.1 3.8 1.5 0.3
0 5.4 3.4 1.7 0.2
0 5.1 3.7 1.5 0.4
0 4.6 3.6 1.0 0.2
0 5.1 3.3 1.7 0.5
0 4.8 3.4 1.9 0.2
1 7.0 3.2 4.7 1.4
1 6.4 3.2 4.5 1.5
1 6.9 3.1 4.9 1.5
1 5.5 2.3 4.0 1.3
1 6.5 2.8 4.6 1.5
1 5.7 2.8 4.5 1.3
1 6.3 3.3 4.7 1.6
1 4.9 2.4 3.3 1.0
1 6.6 2.9 4.6 1.3
1 5.2 2.7 3.9 1.4
1 5.0 2.0 3.5 1.0
1 5.9 3.0 4.2 1.5
1 6.0 2.2 4.0 1.0
1 6.1 2.9 4.7 1.4
1 5.6 2.9 3.6 1.3
1 6.7 3.1 4.4 1.4
1 5.6 3.0 4.5 1.5
1 5.8 2.7 4.1 1.0
1 6.2 2.2 4.5 1.5
1 5.6 2.5 3.9 1.1
1 5.9 3.2 4.8 1.8
1 6.1 2.8 4.0 1.3
1 6.3 2.5 4.9 1.5
1 6.1 2.8 4.7 1.2
1 6.4 2.9 4.3 1.3
2 6.3 3.3 6.0 2.5
2 5.8 2.7 5.1 1.9
2 7.1 3.0 5.9 2.1
2 6.3 2.9 5.6 1.8
2 6.5 3.0 5.8 2.2
2 7.6 3.0 6.6 2.1
2 4.9 2.5 4.5 1.7
2 7.3 2.9 6.3 1.8
2 6.7 2.5 5.8 1.8
2 7.2 3.6 6.1 2.5
2 6.5 3.2 5.1 2.0
2 6.4 2.7 5.3 1.9
2 6.8 3.0 5.5 2.1
2 5.7 2.5 5.0 2.0
2 5.8 2.8 5.1 2.4
2 6.4 3.2 5.3 2.3
2 6.5 3.0 5.5 1.8
2 7.7 3.8 6.7 2.2
2 7.7 2.6 6.9 2.3
2 6.0 2.2 5.0 1.5
2 6.9 3.2 5.7 2.3
2 5.6 2.8 4.9 2.0
2 7.7 2.8 6.7 2.0
2 6.3 2.7 4.9 1.8
2 6.7 3.3 5.7 2.1