티스토리 뷰
손글씨 숫자 인식
손글씨를 신경망 구조를 활용해 0-9까지의 숫자중 어떤 값을 나타내는지 분류하는 과정 중 이미 학습된 매개변수를 사용해 추론 과정을 알아보자!
https://github.com/sujin-create/deep-learning-from-scratch.git
GitHub - sujin-create/deep-learning-from-scratch: 『밑바닥부터 시작하는 딥러닝』(한빛미디어, 2017)
『밑바닥부터 시작하는 딥러닝』(한빛미디어, 2017). Contribute to sujin-create/deep-learning-from-scratch development by creating an account on GitHub.
github.com
제공하는 코드를 fork한뒤 코드를 이해하며 돌려보자!
코드를 각 기능별로 정리할 것이기에 정확한 모든 코드를 볼 것이라면 위의 깃허브를 활용!
1. 우선 dataset폴더 안의 mnist.py중 load_mnist를 통해서 mnist 데이터를 읽어옵니다.
세팅 : ch03으로 위치를 작업 환경을 옮겨 부모 디렉터리 파일 dataset에 접근하기 위한 코드를 작성합니다. 이후 common파일도 마찬가지.
import sys, os
sys.path.append(os.pardir)
읽어오기 : load_mnist => MNIST 데이터를 (training set의 입력값, training set의 실제값) , (testset의 입력값, testset의 실제값)으로 데이터들을 반환합니다.
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
<< bool type의 세가지 옵션 >>
1. normalize
=> 정규화(값을 0-1사이의 값으로 설정해줌)시켜주는 옵션.
2. flatten
=> 행렬 공부시 봤겠지만 평탄화 시키는 것으로 1차원화 시켜주는 옵션.
3. one_hot_label
=> 정답을 뜻하는 원소만 1로 저장 나머지는 다 0으로 저장하는 옵션.
2. 다음으로 불러온 데이터를 확인해봅시다!
# 이미지기능과 관련된 코드만 살펴봅시다
import numpy as np
from PIL import Image
def img_show(img):
pil_img = Image.fromarray(np.uint8(img))#PIL에서 제공하는 Image.fromarray를 사용해 array로 저장된 이미지값을 시각화할 수 있게끔함.
pil_img.show()
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
#1차원배열로 평탄화 시키며 학습된 데이터들을 불러오며 정규화는 하지 않음.
img = x_train[0]
label = t_train[0]
print(label) # 5
print(img.shape) # (784,)
img = img.reshape(28, 28) # 형상을 원래 이미지의 크기로 변형
#PIL의 Image를 사용하려면 원래 크기를 사용해야한다는 점에 주의하자
# 원래 크기를 사용하기 위해서는 [].reshape를 활용한다.
print(img.shape) # (28, 28)
img_show(img)
* PIL에서 Image.fromarray(np.uint8(img))
=> 배열명.reshape()함수를 활용해 평탄화 시켰던 784크기의 배열을 28*28의 2차원 배열로 늘려줘야한다.
=> 이때 어차피 배열을 2차원으로 늘릴거면 평탄화를 왜시키지?라는 생각이 순간 들었지만 생각해보니 그냥 다음 코드는 이미지화를 위해 필요한코드로 실제 기능을 의미하는 것이 아니라는 것이다.
5
(784,)
(28, 28)
코드를 돌려봤을 때 shape의 값으로 다음과같이 나온다.
5라는 label값이 아래의 손글씨를 의미한다.

라이브러리 및 함수를 적절히 import하고 데이터를 불러온후 이미지 데이터를 확인한 결과 다음과 같은 그림을 확인할 수 있었다.
3. 다음으로는 신경망의 추론처리 과정에 필요한 세 함수를 정의하자!
1. get_data()
위에서 알아본 데이터를 불러오는 코드가 정의된 함수.
2. init_network() : 학습데이터를 불러오는 함수
이전 포스팅에서 알아본 신경망의 매개변수(W와 B)를 정의하는 함수.
def init_network():
with open("sample_weight.pkl", 'rb') as f:
network = pickle.load(f)
return network
<<pickle>>
pickle이라는 기능을 활용. 데이터셋을 2번이상 읽어올 필요없이 순식간에 읽어오게 도움을 주는 기능을 활용.
pickle파일을 open f로 불러오고 network에 매개변수 데이터를 저장한다. 이때, network는 딕셔너리 형태로 데이터를 저장중.
3. predict(network,x)
get_data에서 얻은 데이터 중 학습데이터 입력값을 init_network()에서 얻은 매개변수 x로 넣어주고 가중치와 편향을 도입시켜 학습데이터 예측치를 반환하는 함수.
=> 해당 코드의 이해는 이전 포스팅을 보면 가능.
def predict(network, x):
w1, w2, w3 = network['W1'], network['W2'], network['W3']
b1, b2, b3 = network['b1'], network['b2'], network['b3']
a1 = np.dot(x, w1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1, w2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2, w3) + b3
y = softmax(a3)
return y
* 세가지 함수를 활용해 각 데이터들의 정보가 저장된 784의 사이즈를 가지는 배열을 x값으로 매개변수를 network로 넣어주며 predict함수를 호출해도 되지만, batch_size를 정하는 것이 효율적.
=> 대부분의 수치 계산 라이브러리 큰 배열을 효율적으로 처리할 수 있도록 설정이 되어있기 때문이다.
=> np.argmax()를 사용해 확률이 가장 높은 원소의 인덱스를 얻어 그 값을 예측값으로 지정해준다.
#세가지 함수를 활용해 코드를 돌려줌
x, t = get_data()
network = init_network()
batch_size = 100 # 배치 크기
accuracy_cnt = 0
for i in range(0, len(x), batch_size):
x_batch = x[i:i+batch_size]
y_batch = predict(network, x_batch)
p = np.argmax(y_batch, axis=1)
accuracy_cnt += np.sum(p == t[i:i+batch_size])
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
$ python neuralnet_mnist_batch.py
Accuracy:0.9352
<<결론>>
퍼셉트론과 다르게 신경망은 sigmoid함수를 불러와 활성화 함수로 사용하였고 이 차이가 신경망 학습에서 중요한 역할을 한다는 것을 기억하자. 순전파로 진행되는 MNIST모델을 test해보는 과정을 보며 network()를 어떻게 test데이터셋을 활용해 예측값을 도출하고 정확도를 확인할 수 있는지 딥러닝 모델을 이해해봄. 배치처리를 활용하면 훨씬 빠르게 결과를 얻을 수 있따는 것을 기억하자.
'데이터분석 및 인공지능 > 밑바닥부터 시작하는 딥러닝' 카테고리의 다른 글
밑바닥부터 시작하는 딥러닝-(4) 신경망 (0) | 2021.09.23 |
---|---|
밑바닥부터 시작하는 딥러닝-(3) (0) | 2021.09.20 |
밑바닥부터 시작하는 딥러닝-(2) (0) | 2021.09.17 |
밑바닥부터 시작하는 딥러닝 -(1) (0) | 2021.09.17 |
- Total
- Today
- Yesterday
- 기본 텍스트 분류
- LAMBDA
- stack 컨테이너
- 백준 4963
- 온라인프로필 만들기
- c++덱
- 백트래킹(1)
- 백준 숫자놀이
- 파이썬 알아두면 유용
- 기사작성 대외활동
- 스택 파이썬
- mm1queue
- 11053 백준
- 소프트웨어공학설계
- 4963 섬의개수
- 10866 백준
- 코딩월드뉴스
- 딥러닝입문
- 백준 15650 파이썬
- 백준 11053 파이썬
- 모듈 사용법
- 백준 10866
- 영화 리뷰 긍정 부정 분류
- DRF 회원관리
- 효율적인방법찾기
- CSMA/CD란?
- 시뮬레이션 c
- 핀테크 트렌드
- 13886
- CREATE ASSERTION
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 |