ニューラルネット(Python)
Newral network(NN)を組んでみました。統計学や機械学習の分野で有名なiris.dataを用い、NNの正当性を評価しました。iris.dataは150行からなり、1行ずつに、一つのアヤメの分類と特徴が書いてあります。アヤメは3種類に分類され、4つの特徴があります。行の最初に、分類される番号、そして次に4つの特徴が書いてあります。このプログラムでは、149個のアヤメのデータを使って、残り1つのアヤメを分類します。
#! /usr/bin/env python import numpy as np import scipy.spatial.distance as dist class newral_network: def __init__(self): self.data=[] self.w_first=[] self.w_second=[] self.i_layer=[] self.h_layer=[] self.o_layer=[] self.teacher=[] self.eta=0.7 self.num_chara=0 self.num_class=0 def read_file(self,readfile): self.data=np.loadtxt(readfile) self.num_chara=int(len(self.data[0])-1) self.num_class=int(self.data[-1][0]+1) #self.num_class=5 self.w_first=np.ones((self.num_chara,self.num_chara)) self.w_second=np.ones((self.num_chara,self.num_class)) self.t_mat=np.identity(self.num_class) print "the number of charactor and class:",self.num_chara,",",self.num_class def activation_func(self,x): return 1/(1+np.exp(-x)) def back_propagation(self,ref_num): deltak=np.zeros(self.num_class) dw_first=np.zeros((self.num_chara,self.num_chara)) dw_second=np.zeros((self.num_chara,self.num_class)) for k in range(self.num_class): deltak[k]=(self.t_mat[ref_num,k]-self.o_layer[k])*self.o_layer[k]*(1-self.o_layer[k]) for j in range(self.num_chara): dw_second[j,k]=self.eta*self.h_layer[j]*deltak[k] self.w_second+=dw_second for j in range(self.num_chara): deltaj=self.h_layer[j]*(1-self.h_layer[j])*np.dot(self.w_second[j],deltak) for i in range(self.num_chara): dw_first[i,j]=self.eta*self.i_layer[i]*deltaj self.w_first+=dw_first def foward(self): self.h_layer=self.activation_func(np.dot(self.i_layer,self.w_first)) self.o_layer=self.activation_func(np.dot(self.h_layer,self.w_second)) def error(self,t_num): return dist.pdist((self.o_layer,self.t_mat[t_num]),'euclidean')/2.0 def main(self): num_leaning=149 for j in range(4): for i in range(num_leaning): self.i_layer=self.data[i][1:] self.foward() self.back_propagation(self.data[i][0]) print self.error(self.data[i][0]) self.i_layer=self.data[149][1:] self.foward() print self.o_layer if __name__=="__main__": tmpclass=newral_network() tmpclass.read_file("iris.data") tmpclass.main()
結果、
[ 0.0482195 0.07456389 0.92151814]
この場合、クラス3に分類されることを意味し、NNの正当性が確かめられました。
iris.data
0 5.1 3.5 1.4 0.2 0 4.9 3.0 1.4 0.2 0 4.7 3.2 1.3 0.2 0 4.6 3.1 1.5 0.2 0 5.0 3.6 1.4 0.2 0 5.4 3.9 1.7 0.4 0 4.6 3.4 1.4 0.3 0 5.0 3.4 1.5 0.2 0 4.4 2.9 1.4 0.2 0 4.9 3.1 1.5 0.1 0 5.4 3.7 1.5 0.2 0 4.8 3.4 1.6 0.2 0 4.8 3.0 1.4 0.1 0 4.3 3.0 1.1 0.1 0 5.8 4.0 1.2 0.2 0 5.7 4.4 1.5 0.4 0 5.4 3.9 1.3 0.4 0 5.1 3.5 1.4 0.3 0 5.7 3.8 1.7 0.3 0 5.1 3.8 1.5 0.3 0 5.4 3.4 1.7 0.2 0 5.1 3.7 1.5 0.4 0 4.6 3.6 1.0 0.2 0 5.1 3.3 1.7 0.5 0 4.8 3.4 1.9 0.2 0 5.0 3.0 1.6 0.2 0 5.0 3.4 1.6 0.4 0 5.2 3.5 1.5 0.2 0 5.2 3.4 1.4 0.2 0 4.7 3.2 1.6 0.2 0 4.8 3.1 1.6 0.2 0 5.4 3.4 1.5 0.4 0 5.2 4.1 1.5 0.1 0 5.5 4.2 1.4 0.2 0 4.9 3.1 1.5 0.2 0 5.0 3.2 1.2 0.2 0 5.5 3.5 1.3 0.2 0 4.9 3.6 1.4 0.1 0 4.4 3.0 1.3 0.2 0 5.1 3.4 1.5 0.2 0 5.0 3.5 1.3 0.3 0 4.5 2.3 1.3 0.3 0 4.4 3.2 1.3 0.2 0 5.0 3.5 1.6 0.6 0 5.1 3.8 1.9 0.4 0 4.8 3.0 1.4 0.3 0 5.1 3.8 1.6 0.2 0 4.6 3.2 1.4 0.2 0 5.3 3.7 1.5 0.2 0 5.0 3.3 1.4 0.2 1 7.0 3.2 4.7 1.4 1 6.4 3.2 4.5 1.5 1 6.9 3.1 4.9 1.5 1 5.5 2.3 4.0 1.3 1 6.5 2.8 4.6 1.5 1 5.7 2.8 4.5 1.3 1 6.3 3.3 4.7 1.6 1 4.9 2.4 3.3 1.0 1 6.6 2.9 4.6 1.3 1 5.2 2.7 3.9 1.4 1 5.0 2.0 3.5 1.0 1 5.9 3.0 4.2 1.5 1 6.0 2.2 4.0 1.0 1 6.1 2.9 4.7 1.4 1 5.6 2.9 3.6 1.3 1 6.7 3.1 4.4 1.4 1 5.6 3.0 4.5 1.5 1 5.8 2.7 4.1 1.0 1 6.2 2.2 4.5 1.5 1 5.6 2.5 3.9 1.1 1 5.9 3.2 4.8 1.8 1 6.1 2.8 4.0 1.3 1 6.3 2.5 4.9 1.5 1 6.1 2.8 4.7 1.2 1 6.4 2.9 4.3 1.3 1 6.6 3.0 4.4 1.4 1 6.8 2.8 4.8 1.4 1 6.7 3.0 5.0 1.7 1 6.0 2.9 4.5 1.5 1 5.7 2.6 3.5 1.0 1 5.5 2.4 3.8 1.1 1 5.5 2.4 3.7 1.0 1 5.8 2.7 3.9 1.2 1 6.0 2.7 5.1 1.6 1 5.4 3.0 4.5 1.5 1 6.0 3.4 4.5 1.6 1 6.7 3.1 4.7 1.5 1 6.3 2.3 4.4 1.3 1 5.6 3.0 4.1 1.3 1 5.5 2.5 4.0 1.3 1 5.5 2.6 4.4 1.2 1 6.1 3.0 4.6 1.4 1 5.8 2.6 4.0 1.2 1 5.0 2.3 3.3 1.0 1 5.6 2.7 4.2 1.3 1 5.7 3.0 4.2 1.2 1 5.7 2.9 4.2 1.3 1 6.2 2.9 4.3 1.3 1 5.1 2.5 3.0 1.1 1 5.7 2.8 4.1 1.3 2 6.3 3.3 6.0 2.5 2 5.8 2.7 5.1 1.9 2 7.1 3.0 5.9 2.1 2 6.3 2.9 5.6 1.8 2 6.5 3.0 5.8 2.2 2 7.6 3.0 6.6 2.1 2 4.9 2.5 4.5 1.7 2 7.3 2.9 6.3 1.8 2 6.7 2.5 5.8 1.8 2 7.2 3.6 6.1 2.5 2 6.5 3.2 5.1 2.0 2 6.4 2.7 5.3 1.9 2 6.8 3.0 5.5 2.1 2 5.7 2.5 5.0 2.0 2 5.8 2.8 5.1 2.4 2 6.4 3.2 5.3 2.3 2 6.5 3.0 5.5 1.8 2 7.7 3.8 6.7 2.2 2 7.7 2.6 6.9 2.3 2 6.0 2.2 5.0 1.5 2 6.9 3.2 5.7 2.3 2 5.6 2.8 4.9 2.0 2 7.7 2.8 6.7 2.0 2 6.3 2.7 4.9 1.8 2 6.7 3.3 5.7 2.1 2 7.2 3.2 6.0 1.8 2 6.2 2.8 4.8 1.8 2 6.1 3.0 4.9 1.8 2 6.4 2.8 5.6 2.1 2 7.2 3.0 5.8 1.6 2 7.4 2.8 6.1 1.9 2 7.9 3.8 6.4 2.0 2 6.4 2.8 5.6 2.2 2 6.3 2.8 5.1 1.5 2 6.1 2.6 5.6 1.4 2 7.7 3.0 6.1 2.3 2 6.3 3.4 5.6 2.4 2 6.4 3.1 5.5 1.8 2 6.0 3.0 4.8 1.8 2 6.9 3.1 5.4 2.1 2 6.7 3.1 5.6 2.4 2 6.9 3.1 5.1 2.3 2 5.8 2.7 5.1 1.9 2 6.8 3.2 5.9 2.3 2 6.7 3.3 5.7 2.5 2 6.7 3.0 5.2 2.3 2 6.3 2.5 5.0 1.9 2 6.5 3.0 5.2 2.0 2 6.2 3.4 5.4 2.3 2 5.9 3.0 5.1 1.8