반응형
Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
Tags
- tensorflow
- web
- mglearn
- CES 2O21 참가
- web 개발
- classification
- html
- cudnn
- 자료구조
- web 사진
- KNeighborsClassifier
- bccard
- 웹 용어
- web 용어
- discrete_scatter
- paragraph
- Keras
- broscoding
- 데이터전문기관
- vscode
- 재귀함수
- 머신러닝
- 대이터
- 결합전문기관
- java역사
- inorder
- C언어
- postorder
- pycharm
- CES 2O21 참여
Archives
- Today
- Total
bro's coding
중간층 만들기(분류) 본문
반응형
data : iris
X : iris.data
y : iris.target
w : random.randn
중간층의 활성함수 : sigmoid
- 중간층 뉴런 10개를 포함해 y 값 분류하기.
import numpy as np
from sklearn.datasets import load_iris
# data 준비
iris=load_iris()
X=iris.data
y=iris.target
w=np.random.randn(4,10)
b=np.random.randn(10)
def sigmoid(t):
return 1/(1+np.exp(-t))
u=sigmoid(X@w+b)
ww=np.random.randn(10,3)
bb=np.random.randn(3)
def softmax(t):
return np.exp(t)/np.sum(np.exp(t),axis=1).reshape(-1,1)
pred_y=softmax(u@ww+bb)
display(pred_y)
array([[0.0946966 , 0.15820583, 0.74709757],
[0.08461151, 0.13753592, 0.77785257],
[0.09187484, 0.15265191, 0.75547325],
[0.08750447, 0.15223665, 0.76025887],
[0.09658978, 0.16489912, 0.7385111 ],
[0.09683732, 0.17517545, 0.72798723],
[0.09430753, 0.16681925, 0.73887323],
[0.09180266, 0.15711102, 0.75108632],
[0.08554709, 0.14606362, 0.76838929],
[0.08481409, 0.14141126, 0.77377466],
[0.09588453, 0.16129611, 0.74281935],
[0.09127937, 0.16284669, 0.74587394],
[0.08448227, 0.13772764, 0.77779009],
[0.09194439, 0.14820171, 0.7598539 ],
[0.10389598, 0.15986338, 0.73624065],
[0.10355146, 0.1828259 , 0.71362264],
[0.1015654 , 0.16981772, 0.72861688],
[0.09506259, 0.16025495, 0.74468246],
[0.09437382, 0.16298609, 0.74264009],
[0.09796686, 0.17338912, 0.72864402],
[0.08664464, 0.14818985, 0.7651655 ],
[0.09678975, 0.17128509, 0.73192515],
[0.10153156, 0.1665273 , 0.73194114],
[0.088152 , 0.15818986, 0.75365814],
[0.08770561, 0.16468911, 0.74760528],
[0.08079425, 0.13617472, 0.78303103],
[0.0912909 , 0.16227999, 0.74642911],
[0.09291534, 0.15690741, 0.75017725],
[0.09229029, 0.15076004, 0.75694968],
[0.0877941 , 0.15560089, 0.75660501],
[0.08493002, 0.14767432, 0.76739565],
[0.09107647, 0.15137706, 0.75754648],
[0.10112064, 0.1788044 , 0.72007495],
[0.10330346, 0.17661856, 0.72007798],
[0.08481409, 0.14141126, 0.77377466],
[0.09199804, 0.14329198, 0.76470997],
[0.09443527, 0.14624258, 0.75932214],
[0.08481409, 0.14141126, 0.77377466],
[0.08910151, 0.15032915, 0.76056935],
[0.09134322, 0.15455869, 0.75409809],
[0.09661309, 0.16117319, 0.74221371],
[0.07125992, 0.11029176, 0.81844831],
[0.09278196, 0.15978939, 0.74742864],
[0.09307416, 0.1698198 , 0.73710604],
[0.09364276, 0.17851572, 0.72784152],
[0.08611071, 0.14308734, 0.77080195],
[0.09678861, 0.1730636 , 0.73014779],
[0.09087866, 0.15627705, 0.75284429],
[0.09618947, 0.16366132, 0.7401492 ],
[0.09115483, 0.15103548, 0.75780969],
[0.04795394, 0.10404915, 0.8479969 ],
[0.05472044, 0.12161601, 0.82366356],
[0.0466784 , 0.09820204, 0.85511957],
[0.0488499 , 0.08787753, 0.86327257],
[0.0464027 , 0.09417933, 0.85941797],
[0.05285514, 0.1008173 , 0.84632756],
[0.05703639, 0.1279034 , 0.8150602 ],
[0.05645185, 0.11127117, 0.83227697],
[0.04622812, 0.0933376 , 0.86043428],
[0.05855797, 0.12274892, 0.8186931 ],
[0.04966144, 0.08162653, 0.86871203],
[0.05684349, 0.12541015, 0.81774635],
[0.04421547, 0.07308939, 0.88269514],
[0.05049237, 0.09916586, 0.85034176],
[0.06023855, 0.13221004, 0.80755141],
[0.05013542, 0.10951931, 0.84034527],
[0.05825237, 0.12209882, 0.8196488 ],
[0.04999684, 0.09632243, 0.85368073],
[0.04391092, 0.07081962, 0.88526946],
[0.04983704, 0.0951333 , 0.85502966],
[0.05923867, 0.1313591 , 0.80940223],
[0.05108432, 0.10823396, 0.84068172],
[0.04500784, 0.07273054, 0.88226162],
[0.04842929, 0.08772317, 0.86384754],
[0.04887543, 0.10267318, 0.84845139],
[0.04907709, 0.1056074 , 0.84531551],
[0.04366809, 0.08426389, 0.87206802],
[0.0472566 , 0.09778439, 0.85495901],
[0.05251329, 0.10932271, 0.83816401],
[0.05171251, 0.10526662, 0.84302086],
[0.0495982 , 0.09322235, 0.85717946],
[0.04942711, 0.09296316, 0.85760973],
[0.05176631, 0.10672396, 0.84150973],
[0.04946777, 0.08094933, 0.86958291],
[0.06037262, 0.1250462 , 0.81458118],
[0.06280554, 0.14461318, 0.79258128],
[0.04908697, 0.10570248, 0.84521056],
[0.04344865, 0.07287174, 0.88367961],
[0.0590129 , 0.12732653, 0.81366057],
[0.05135749, 0.09939739, 0.84924512],
[0.05182 , 0.09047857, 0.85770143],
[0.05244212, 0.108406 , 0.83915188],
[0.04971666, 0.09854295, 0.85174039],
[0.05371682, 0.10343661, 0.84284658],
[0.05296526, 0.10481056, 0.84222419],
[0.05696018, 0.11986909, 0.82317073],
[0.05555587, 0.11631121, 0.82813292],
[0.05052354, 0.10570875, 0.8437677 ],
[0.05908495, 0.12281577, 0.81809928],
[0.05412585, 0.11251296, 0.83336118],
[0.05892522, 0.11464751, 0.82642727],
[0.05277266, 0.09150988, 0.85571746],
[0.04438705, 0.07842599, 0.87718696],
[0.0497325 , 0.07947393, 0.87079357],
[0.050514 , 0.08902168, 0.86046433],
[0.04237111, 0.05755216, 0.90007673],
[0.06021823, 0.10219561, 0.83758616],
[0.04336227, 0.05643004, 0.90020769],
[0.04477436, 0.0547636 , 0.90046204],
[0.05318961, 0.12039643, 0.82641396],
[0.05349388, 0.11918407, 0.82732206],
[0.04672981, 0.0800948 , 0.87317538],
[0.04700159, 0.0932671 , 0.85973131],
[0.05199184, 0.08653615, 0.86147201],
[0.0568307 , 0.11580197, 0.82736733],
[0.05543255, 0.12441096, 0.8201565 ],
[0.04869655, 0.08733792, 0.86396553],
[0.04866142, 0.09801144, 0.85332714],
[0.04273827, 0.0438471 , 0.91341463],
[0.04651292, 0.05856649, 0.89492059],
[0.04949006, 0.10435044, 0.84615951],
[0.05687823, 0.11120644, 0.83191532],
[0.04244046, 0.0481131 , 0.90944644],
[0.04726718, 0.09053282, 0.8622 ],
[0.05200276, 0.107151 , 0.84084623],
[0.04482672, 0.07800224, 0.87717104],
[0.04961861, 0.10058484, 0.84979655],
[0.05346437, 0.11170703, 0.83482861],
[0.04869773, 0.08151788, 0.86978439],
[0.04311647, 0.06980073, 0.8870828 ],
[0.04197187, 0.0593613 , 0.89866683],
[0.04667886, 0.0993745 , 0.85394664],
[0.04911126, 0.08450447, 0.86638427],
[0.04741521, 0.08193148, 0.8706533 ],
[0.04873098, 0.0568308 , 0.89443822],
[0.04065922, 0.07444947, 0.88489131],
[0.05946292, 0.13283009, 0.80770699],
[0.05100133, 0.09457131, 0.85442736],
[0.0547551 , 0.11632282, 0.82892208],
[0.04768056, 0.10120087, 0.85111857],
[0.0507378 , 0.10698793, 0.84227428],
[0.04993986, 0.11556938, 0.83449075],
[0.05277266, 0.09150988, 0.85571746],
[0.05030188, 0.09878192, 0.8509162 ],
[0.05425893, 0.12050797, 0.8252331 ],
[0.04958551, 0.10985085, 0.84056364],
[0.04570257, 0.07945324, 0.87484419],
[0.04982687, 0.10313901, 0.84703412],
[0.06024747, 0.13738656, 0.80236597],
[0.05525639, 0.10751219, 0.83723141]])
onehot_y=np.eye(3)[y]
array([[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[1., 0., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 1., 0.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.],
[0., 0., 1.]])
# 이렇게 하면 안됨! 분모에 문제가 생김
# cross_entropy=(-(onehot_y*np.log(pred_y+1e-7))).mean(axis=0)
# 전체에 대한 cross entropy error을 알고 싶으면
cross_entropy = -(onehot_y*np.log(pred_y+1e-7)).sum()/len(X)
# len(X) : 150
# 각각에 대한 cross entropy error을 알고 싶으면
cross_entropies=[-np.log(pred_y[np.where(y==i)]+1e-7).mean() for i in range(3)]
cross_entropy,cross_entropies
(1.606136586854307,
[1.5104574916079385, 1.8035330164058632, 1.857196530019606])
반응형
'[AI] > python.Neural_Network' 카테고리의 다른 글
다중 분류 원리 (0) | 2020.05.08 |
---|---|
경사하강법 (0) | 2020.05.08 |
중간층 만들기(값예측) (0) | 2020.05.07 |
activation function(활성 함수) (0) | 2020.05.07 |
neural_network.logisticRegression.원리알기 (0) | 2020.05.06 |
Comments