おいも貴婦人ブログ

生物系博士課程満期退学をしたAIエンジニアのブログ。

ニューラルネット(Python)

Newral network(NN)を組んでみました。統計学機械学習の分野で有名なiris.dataを用い、NNの正当性を評価しました。iris.dataは150行からなり、1行ずつに、一つのアヤメの分類と特徴が書いてあります。アヤメは3種類に分類され、4つの特徴があります。行の最初に、分類される番号、そして次に4つの特徴が書いてあります。このプログラムでは、149個のアヤメのデータを使って、残り1つのアヤメを分類します。

#! /usr/bin/env python
import numpy as np
import scipy.spatial.distance as dist

class newral_network:
    def __init__(self):
        self.data=[]
        self.w_first=[]
        self.w_second=[]
        self.i_layer=[]
        self.h_layer=[]
        self.o_layer=[]
        self.teacher=[]
        self.eta=0.7
        self.num_chara=0
        self.num_class=0


    def read_file(self,readfile):
        self.data=np.loadtxt(readfile)
        self.num_chara=int(len(self.data[0])-1)
        self.num_class=int(self.data[-1][0]+1)
        #self.num_class=5
        self.w_first=np.ones((self.num_chara,self.num_chara))
        self.w_second=np.ones((self.num_chara,self.num_class))
        self.t_mat=np.identity(self.num_class)
        print "the number of charactor and class:",self.num_chara,",",self.num_class

    def activation_func(self,x):
        return 1/(1+np.exp(-x))
    
    def back_propagation(self,ref_num):
        deltak=np.zeros(self.num_class)
        dw_first=np.zeros((self.num_chara,self.num_chara))
        dw_second=np.zeros((self.num_chara,self.num_class))

        for k in range(self.num_class):
            deltak[k]=(self.t_mat[ref_num,k]-self.o_layer[k])*self.o_layer[k]*(1-self.o_layer[k])
            for j in range(self.num_chara):

                dw_second[j,k]=self.eta*self.h_layer[j]*deltak[k]
        self.w_second+=dw_second
        for j in range(self.num_chara):
            deltaj=self.h_layer[j]*(1-self.h_layer[j])*np.dot(self.w_second[j],deltak)
            for i in range(self.num_chara):
                dw_first[i,j]=self.eta*self.i_layer[i]*deltaj
        self.w_first+=dw_first

    def foward(self):
        self.h_layer=self.activation_func(np.dot(self.i_layer,self.w_first))
        self.o_layer=self.activation_func(np.dot(self.h_layer,self.w_second))

    def error(self,t_num):
        return dist.pdist((self.o_layer,self.t_mat[t_num]),'euclidean')/2.0
        
    def main(self):
        num_leaning=149
        for j in range(4):
            for i in range(num_leaning):
                self.i_layer=self.data[i][1:]
                self.foward()
                self.back_propagation(self.data[i][0])
                print self.error(self.data[i][0])
        self.i_layer=self.data[149][1:]
        self.foward()
        print self.o_layer

if __name__=="__main__":
    tmpclass=newral_network()
    tmpclass.read_file("iris.data")
    tmpclass.main()

結果、

[ 0.0482195   0.07456389  0.92151814]

この場合、クラス3に分類されることを意味し、NNの正当性が確かめられました。

iris.data

0 5.1 3.5 1.4 0.2
0 4.9 3.0 1.4 0.2
0 4.7 3.2 1.3 0.2
0 4.6 3.1 1.5 0.2
0 5.0 3.6 1.4 0.2
0 5.4 3.9 1.7 0.4
0 4.6 3.4 1.4 0.3
0 5.0 3.4 1.5 0.2
0 4.4 2.9 1.4 0.2
0 4.9 3.1 1.5 0.1
0 5.4 3.7 1.5 0.2
0 4.8 3.4 1.6 0.2
0 4.8 3.0 1.4 0.1
0 4.3 3.0 1.1 0.1
0 5.8 4.0 1.2 0.2
0 5.7 4.4 1.5 0.4
0 5.4 3.9 1.3 0.4
0 5.1 3.5 1.4 0.3
0 5.7 3.8 1.7 0.3
0 5.1 3.8 1.5 0.3
0 5.4 3.4 1.7 0.2
0 5.1 3.7 1.5 0.4
0 4.6 3.6 1.0 0.2
0 5.1 3.3 1.7 0.5
0 4.8 3.4 1.9 0.2
0 5.0 3.0 1.6 0.2
0 5.0 3.4 1.6 0.4
0 5.2 3.5 1.5 0.2
0 5.2 3.4 1.4 0.2
0 4.7 3.2 1.6 0.2
0 4.8 3.1 1.6 0.2
0 5.4 3.4 1.5 0.4
0 5.2 4.1 1.5 0.1
0 5.5 4.2 1.4 0.2
0 4.9 3.1 1.5 0.2
0 5.0 3.2 1.2 0.2
0 5.5 3.5 1.3 0.2
0 4.9 3.6 1.4 0.1
0 4.4 3.0 1.3 0.2
0 5.1 3.4 1.5 0.2
0 5.0 3.5 1.3 0.3
0 4.5 2.3 1.3 0.3
0 4.4 3.2 1.3 0.2
0 5.0 3.5 1.6 0.6
0 5.1 3.8 1.9 0.4
0 4.8 3.0 1.4 0.3
0 5.1 3.8 1.6 0.2
0 4.6 3.2 1.4 0.2
0 5.3 3.7 1.5 0.2
0 5.0 3.3 1.4 0.2
1 7.0 3.2 4.7 1.4
1 6.4 3.2 4.5 1.5
1 6.9 3.1 4.9 1.5
1 5.5 2.3 4.0 1.3
1 6.5 2.8 4.6 1.5
1 5.7 2.8 4.5 1.3
1 6.3 3.3 4.7 1.6
1 4.9 2.4 3.3 1.0
1 6.6 2.9 4.6 1.3
1 5.2 2.7 3.9 1.4
1 5.0 2.0 3.5 1.0
1 5.9 3.0 4.2 1.5
1 6.0 2.2 4.0 1.0
1 6.1 2.9 4.7 1.4
1 5.6 2.9 3.6 1.3
1 6.7 3.1 4.4 1.4
1 5.6 3.0 4.5 1.5
1 5.8 2.7 4.1 1.0
1 6.2 2.2 4.5 1.5
1 5.6 2.5 3.9 1.1
1 5.9 3.2 4.8 1.8
1 6.1 2.8 4.0 1.3
1 6.3 2.5 4.9 1.5
1 6.1 2.8 4.7 1.2
1 6.4 2.9 4.3 1.3
1 6.6 3.0 4.4 1.4
1 6.8 2.8 4.8 1.4
1 6.7 3.0 5.0 1.7
1 6.0 2.9 4.5 1.5
1 5.7 2.6 3.5 1.0
1 5.5 2.4 3.8 1.1
1 5.5 2.4 3.7 1.0
1 5.8 2.7 3.9 1.2
1 6.0 2.7 5.1 1.6
1 5.4 3.0 4.5 1.5
1 6.0 3.4 4.5 1.6
1 6.7 3.1 4.7 1.5
1 6.3 2.3 4.4 1.3
1 5.6 3.0 4.1 1.3
1 5.5 2.5 4.0 1.3
1 5.5 2.6 4.4 1.2
1 6.1 3.0 4.6 1.4
1 5.8 2.6 4.0 1.2
1 5.0 2.3 3.3 1.0
1 5.6 2.7 4.2 1.3
1 5.7 3.0 4.2 1.2
1 5.7 2.9 4.2 1.3
1 6.2 2.9 4.3 1.3
1 5.1 2.5 3.0 1.1
1 5.7 2.8 4.1 1.3
2 6.3 3.3 6.0 2.5
2 5.8 2.7 5.1 1.9
2 7.1 3.0 5.9 2.1
2 6.3 2.9 5.6 1.8
2 6.5 3.0 5.8 2.2
2 7.6 3.0 6.6 2.1
2 4.9 2.5 4.5 1.7
2 7.3 2.9 6.3 1.8
2 6.7 2.5 5.8 1.8
2 7.2 3.6 6.1 2.5
2 6.5 3.2 5.1 2.0
2 6.4 2.7 5.3 1.9
2 6.8 3.0 5.5 2.1
2 5.7 2.5 5.0 2.0
2 5.8 2.8 5.1 2.4
2 6.4 3.2 5.3 2.3
2 6.5 3.0 5.5 1.8
2 7.7 3.8 6.7 2.2
2 7.7 2.6 6.9 2.3
2 6.0 2.2 5.0 1.5
2 6.9 3.2 5.7 2.3
2 5.6 2.8 4.9 2.0
2 7.7 2.8 6.7 2.0
2 6.3 2.7 4.9 1.8
2 6.7 3.3 5.7 2.1
2 7.2 3.2 6.0 1.8
2 6.2 2.8 4.8 1.8
2 6.1 3.0 4.9 1.8
2 6.4 2.8 5.6 2.1
2 7.2 3.0 5.8 1.6
2 7.4 2.8 6.1 1.9
2 7.9 3.8 6.4 2.0
2 6.4 2.8 5.6 2.2
2 6.3 2.8 5.1 1.5
2 6.1 2.6 5.6 1.4
2 7.7 3.0 6.1 2.3
2 6.3 3.4 5.6 2.4
2 6.4 3.1 5.5 1.8
2 6.0 3.0 4.8 1.8
2 6.9 3.1 5.4 2.1
2 6.7 3.1 5.6 2.4
2 6.9 3.1 5.1 2.3
2 5.8 2.7 5.1 1.9
2 6.8 3.2 5.9 2.3
2 6.7 3.3 5.7 2.5
2 6.7 3.0 5.2 2.3
2 6.3 2.5 5.0 1.9
2 6.5 3.0 5.2 2.0
2 6.2 3.4 5.4 2.3
2 5.9 3.0 5.1 1.8