Building a Hidden Layer (Classification)

givemebro 2020. 5. 7. 15:31
data : iris
X : iris.data
y : iris.target
w : random.randn
hidden-layer activation function : sigmoid

- Goal: classify y using a hidden layer of 10 neurons.
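For orientation, here is a minimal shape check of the two layers described above (a sketch only; X_dummy is a stand-in for the real iris data, and the full forward pass is written out below):

import numpy as np

X_dummy = np.zeros((150, 4))                          # 150 samples, 4 features
w, b = np.random.randn(4, 10), np.random.randn(10)    # hidden layer
ww, bb = np.random.randn(10, 3), np.random.randn(3)   # output layer

u = 1 / (1 + np.exp(-(X_dummy @ w + b)))              # sigmoid hidden activations
print(u.shape)                                        # (150, 10)
print((u @ ww + bb).shape)                            # (150, 3) -> softmax -> class probabilities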

 

import numpy as np
from sklearn.datasets import load_iris

# prepare the data
iris = load_iris()
X = iris.data      # (150, 4) features
y = iris.target    # (150,) class labels 0/1/2

# hidden layer: 4 inputs -> 10 neurons
w = np.random.randn(4, 10)
b = np.random.randn(10)

def sigmoid(t):
    return 1 / (1 + np.exp(-t))

# hidden-layer output, shape (150, 10)
u = sigmoid(X @ w + b)

# output layer: 10 hidden neurons -> 3 classes
ww = np.random.randn(10, 3)
bb = np.random.randn(3)

def softmax(t):
    # normalize each row so the 3 class probabilities sum to 1
    return np.exp(t) / np.sum(np.exp(t), axis=1).reshape(-1, 1)

# predicted class probabilities, shape (150, 3)
pred_y = softmax(u @ ww + bb)

display(pred_y)

 

array([[0.0946966 , 0.15820583, 0.74709757],
       [0.08461151, 0.13753592, 0.77785257],
       [0.09187484, 0.15265191, 0.75547325],
       [0.08750447, 0.15223665, 0.76025887],
       [0.09658978, 0.16489912, 0.7385111 ],
       [0.09683732, 0.17517545, 0.72798723],
       [0.09430753, 0.16681925, 0.73887323],
       [0.09180266, 0.15711102, 0.75108632],
       [0.08554709, 0.14606362, 0.76838929],
       [0.08481409, 0.14141126, 0.77377466],
       [0.09588453, 0.16129611, 0.74281935],
       [0.09127937, 0.16284669, 0.74587394],
       [0.08448227, 0.13772764, 0.77779009],
       [0.09194439, 0.14820171, 0.7598539 ],
       [0.10389598, 0.15986338, 0.73624065],
       [0.10355146, 0.1828259 , 0.71362264],
       [0.1015654 , 0.16981772, 0.72861688],
       [0.09506259, 0.16025495, 0.74468246],
       [0.09437382, 0.16298609, 0.74264009],
       [0.09796686, 0.17338912, 0.72864402],
       [0.08664464, 0.14818985, 0.7651655 ],
       [0.09678975, 0.17128509, 0.73192515],
       [0.10153156, 0.1665273 , 0.73194114],
       [0.088152  , 0.15818986, 0.75365814],
       [0.08770561, 0.16468911, 0.74760528],
       [0.08079425, 0.13617472, 0.78303103],
       [0.0912909 , 0.16227999, 0.74642911],
       [0.09291534, 0.15690741, 0.75017725],
       [0.09229029, 0.15076004, 0.75694968],
       [0.0877941 , 0.15560089, 0.75660501],
       [0.08493002, 0.14767432, 0.76739565],
       [0.09107647, 0.15137706, 0.75754648],
       [0.10112064, 0.1788044 , 0.72007495],
       [0.10330346, 0.17661856, 0.72007798],
       [0.08481409, 0.14141126, 0.77377466],
       [0.09199804, 0.14329198, 0.76470997],
       [0.09443527, 0.14624258, 0.75932214],
       [0.08481409, 0.14141126, 0.77377466],
       [0.08910151, 0.15032915, 0.76056935],
       [0.09134322, 0.15455869, 0.75409809],
       [0.09661309, 0.16117319, 0.74221371],
       [0.07125992, 0.11029176, 0.81844831],
       [0.09278196, 0.15978939, 0.74742864],
       [0.09307416, 0.1698198 , 0.73710604],
       [0.09364276, 0.17851572, 0.72784152],
       [0.08611071, 0.14308734, 0.77080195],
       [0.09678861, 0.1730636 , 0.73014779],
       [0.09087866, 0.15627705, 0.75284429],
       [0.09618947, 0.16366132, 0.7401492 ],
       [0.09115483, 0.15103548, 0.75780969],
       [0.04795394, 0.10404915, 0.8479969 ],
       [0.05472044, 0.12161601, 0.82366356],
       [0.0466784 , 0.09820204, 0.85511957],
       [0.0488499 , 0.08787753, 0.86327257],
       [0.0464027 , 0.09417933, 0.85941797],
       [0.05285514, 0.1008173 , 0.84632756],
       [0.05703639, 0.1279034 , 0.8150602 ],
       [0.05645185, 0.11127117, 0.83227697],
       [0.04622812, 0.0933376 , 0.86043428],
       [0.05855797, 0.12274892, 0.8186931 ],
       [0.04966144, 0.08162653, 0.86871203],
       [0.05684349, 0.12541015, 0.81774635],
       [0.04421547, 0.07308939, 0.88269514],
       [0.05049237, 0.09916586, 0.85034176],
       [0.06023855, 0.13221004, 0.80755141],
       [0.05013542, 0.10951931, 0.84034527],
       [0.05825237, 0.12209882, 0.8196488 ],
       [0.04999684, 0.09632243, 0.85368073],
       [0.04391092, 0.07081962, 0.88526946],
       [0.04983704, 0.0951333 , 0.85502966],
       [0.05923867, 0.1313591 , 0.80940223],
       [0.05108432, 0.10823396, 0.84068172],
       [0.04500784, 0.07273054, 0.88226162],
       [0.04842929, 0.08772317, 0.86384754],
       [0.04887543, 0.10267318, 0.84845139],
       [0.04907709, 0.1056074 , 0.84531551],
       [0.04366809, 0.08426389, 0.87206802],
       [0.0472566 , 0.09778439, 0.85495901],
       [0.05251329, 0.10932271, 0.83816401],
       [0.05171251, 0.10526662, 0.84302086],
       [0.0495982 , 0.09322235, 0.85717946],
       [0.04942711, 0.09296316, 0.85760973],
       [0.05176631, 0.10672396, 0.84150973],
       [0.04946777, 0.08094933, 0.86958291],
       [0.06037262, 0.1250462 , 0.81458118],
       [0.06280554, 0.14461318, 0.79258128],
       [0.04908697, 0.10570248, 0.84521056],
       [0.04344865, 0.07287174, 0.88367961],
       [0.0590129 , 0.12732653, 0.81366057],
       [0.05135749, 0.09939739, 0.84924512],
       [0.05182   , 0.09047857, 0.85770143],
       [0.05244212, 0.108406  , 0.83915188],
       [0.04971666, 0.09854295, 0.85174039],
       [0.05371682, 0.10343661, 0.84284658],
       [0.05296526, 0.10481056, 0.84222419],
       [0.05696018, 0.11986909, 0.82317073],
       [0.05555587, 0.11631121, 0.82813292],
       [0.05052354, 0.10570875, 0.8437677 ],
       [0.05908495, 0.12281577, 0.81809928],
       [0.05412585, 0.11251296, 0.83336118],
       [0.05892522, 0.11464751, 0.82642727],
       [0.05277266, 0.09150988, 0.85571746],
       [0.04438705, 0.07842599, 0.87718696],
       [0.0497325 , 0.07947393, 0.87079357],
       [0.050514  , 0.08902168, 0.86046433],
       [0.04237111, 0.05755216, 0.90007673],
       [0.06021823, 0.10219561, 0.83758616],
       [0.04336227, 0.05643004, 0.90020769],
       [0.04477436, 0.0547636 , 0.90046204],
       [0.05318961, 0.12039643, 0.82641396],
       [0.05349388, 0.11918407, 0.82732206],
       [0.04672981, 0.0800948 , 0.87317538],
       [0.04700159, 0.0932671 , 0.85973131],
       [0.05199184, 0.08653615, 0.86147201],
       [0.0568307 , 0.11580197, 0.82736733],
       [0.05543255, 0.12441096, 0.8201565 ],
       [0.04869655, 0.08733792, 0.86396553],
       [0.04866142, 0.09801144, 0.85332714],
       [0.04273827, 0.0438471 , 0.91341463],
       [0.04651292, 0.05856649, 0.89492059],
       [0.04949006, 0.10435044, 0.84615951],
       [0.05687823, 0.11120644, 0.83191532],
       [0.04244046, 0.0481131 , 0.90944644],
       [0.04726718, 0.09053282, 0.8622    ],
       [0.05200276, 0.107151  , 0.84084623],
       [0.04482672, 0.07800224, 0.87717104],
       [0.04961861, 0.10058484, 0.84979655],
       [0.05346437, 0.11170703, 0.83482861],
       [0.04869773, 0.08151788, 0.86978439],
       [0.04311647, 0.06980073, 0.8870828 ],
       [0.04197187, 0.0593613 , 0.89866683],
       [0.04667886, 0.0993745 , 0.85394664],
       [0.04911126, 0.08450447, 0.86638427],
       [0.04741521, 0.08193148, 0.8706533 ],
       [0.04873098, 0.0568308 , 0.89443822],
       [0.04065922, 0.07444947, 0.88489131],
       [0.05946292, 0.13283009, 0.80770699],
       [0.05100133, 0.09457131, 0.85442736],
       [0.0547551 , 0.11632282, 0.82892208],
       [0.04768056, 0.10120087, 0.85111857],
       [0.0507378 , 0.10698793, 0.84227428],
       [0.04993986, 0.11556938, 0.83449075],
       [0.05277266, 0.09150988, 0.85571746],
       [0.05030188, 0.09878192, 0.8509162 ],
       [0.05425893, 0.12050797, 0.8252331 ],
       [0.04958551, 0.10985085, 0.84056364],
       [0.04570257, 0.07945324, 0.87484419],
       [0.04982687, 0.10313901, 0.84703412],
       [0.06024747, 0.13738656, 0.80236597],
       [0.05525639, 0.10751219, 0.83723141]])
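With purely random weights these probabilities are not meaningful yet; in the output above every row puts most of its mass on the third class. A quick sanity check (a sketch, not part of the original post) is to turn each row into a class prediction with argmax and compare against y:

# most probable class for each sample
pred_labels = np.argmax(pred_y, axis=1)

# with these untrained weights every sample lands on class 2,
# so accuracy is about 1/3 (chance level for the balanced iris data)
accuracy = (pred_labels == y).mean()
print(pred_labels[:10], accuracy)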

 

# one-hot encode the labels: class k -> k-th row of the 3x3 identity matrix
onehot_y = np.eye(3)[y]
display(onehot_y)
array([[1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.]])
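np.eye(3)[y] one-hot encodes the labels by picking the y-th row of the 3x3 identity matrix for each sample. As a cross-check (a sketch using sklearn, not in the original post), OneHotEncoder produces the same matrix, and argmax recovers the original labels:

from sklearn.preprocessing import OneHotEncoder

onehot_alt = OneHotEncoder().fit_transform(y.reshape(-1, 1)).toarray()
print(np.array_equal(onehot_alt, onehot_y))            # True
print(np.array_equal(np.argmax(onehot_y, axis=1), y))  # True: argmax undoes the encoding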

 

# Don't do this! The denominator is wrong: mean(axis=0) divides every class
# column by 150 even though each class has only 50 samples.
# cross_entropy=(-(onehot_y*np.log(pred_y+1e-7))).mean(axis=0)

# cross-entropy error over the whole dataset

cross_entropy = -(onehot_y*np.log(pred_y+1e-7)).sum()/len(X)
# len(X) : 150

# cross-entropy error for each class separately
cross_entropies=[-np.log(pred_y[np.where(y==i)]+1e-7).mean() for i in range(3)]

cross_entropy,cross_entropies
(1.606136586854307,
 [1.5104574916079385, 1.8035330164058632, 1.857196530019606])
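Because onehot_y zeroes out every term except the true-class one, the overall cross-entropy above can equivalently be computed straight from the probability each sample assigns to its own class (a sketch, reusing pred_y and y from above):

# probability assigned to the correct class for each of the 150 samples
true_class_prob = pred_y[np.arange(len(y)), y]

# same value as the onehot_y-based sum above (~1.606 for this run)
cross_entropy_direct = -np.log(true_class_prob + 1e-7).mean()
print(cross_entropy_direct)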