撰寫第一支神經網路程式

利用 Keras 辨識 MNIST手寫數字資料集

Date: 2019/12/28    |    Lecturer: Chia

目次

  • 安裝深度學習套件

  • MNIST手寫數字資料集

  • 五個步驟

    1. 載入資料

    2. 建構神經網路架構

    3. 資料預處理

    4. 模型訓練

    5. 評估測試正確率

安裝深度學習套件

$ pip install keras
$ pip install tensorflow
  • Keras 可說是「最適合初學者」的深度學習套件!
    • 其底層的深度學習計算,可使用
      • TensorFlow
      • Theano

安裝深度學習套件

  • Keras 優點:
    • 內建常用的類神經網路元件。
    • 用最少的程式碼,建構複雜的深度學習網路架構。
  • Keras 缺點:
    • 為了同時與Theano及TensorFlow相容,會損失一些對網路架構的自由度,且沒辦法使用到底層套件的全部功能。

MNIST手寫數字資料集

$ pip install matplotlib
# A_Neural_Network_keras.py

from keras.datasets import mnist
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
plt.imshow(train_images[0], cmap=plt.cm.binary)
plt.show()

print(train_labels[0])

MNIST手寫數字資料集

(補充) MNIST手寫數字資料集

# mnist_subplot.py

from keras.datasets import mnist
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
plt.subplot(1, 4, 1)
plt.imshow(train_images[0], cmap=plt.cm.binary)
plt.subplot(1, 4, 2)
plt.imshow(train_images[1], cmap=plt.cm.binary)
plt.subplot(1, 4, 3)
plt.imshow(train_images[2], cmap=plt.cm.binary)
plt.subplot(1, 4, 4)
plt.imshow(train_images[3], cmap=plt.cm.binary)
plt.show()

五個步驟

1. 載入資料

  • x:圖片
  • y:標籤
  • train:訓練集
  • test:測試集

用來訓練模型

對模型進行測試

五個步驟

2. 建構神經網路架構

3. 資料預處理

訓練集 & 測試集

轉換為神經網路能夠處理的形式

4. 模型訓練

學習把圖片加以歸類

訓練集圖片

五個步驟

5. 評估測試正確率

測試集圖片

預測出來的數字

測試集標籤

五個步驟 - 1. 載入資料

# A_Neural_Network_keras.py

from keras.datasets import mnist
from keras import models
from keras import layers
from keras.utils import to_categorical

匯入所需套件

五個步驟 - 1. 載入資料

# A_Neural_Network_keras.py

# ...

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

print(train_labels[0], test_labels[0]) 
# 5 7

載入MNIST

五個步驟 - 2. 建構神經網路架構

relu

softmax

Activation Function

激勵函數

五個步驟 - 2. 建構神經網路架構

# A_Neural_Network_keras.py

# ...

network = models.Sequential()
			#激勵函數:relu, softmax
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))

# 編譯
network.compile(optimizer='rmsprop', 			#優化器
                loss='categorical_crossentropy', 	#損失函數
                metrics=['accuracy']) 			#評量準則

建構神經網路架構

五個步驟 - 3. 資料預處理

# A_Neural_Network_keras.py

# ...

# 訓練集 & 測試集 轉換為神經網路能夠處理的形式
# 並縮放到所有值都在 [0, 1] 區間 (除以255)
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255

test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255

資料預處理

五個步驟 - 3. 資料預處理

  • One-hot encoding 編碼邏輯
    • 類別拆成多個行。
    • 任意維度的向量中,僅一個維度的值是1,其餘爲0。

五個步驟 - 3. 資料預處理

# A_Neural_Network_keras.py

# ...

# 對標籤進行分類編碼(One-hot encoding) => 目標:預測數字
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

print(train_labels[0], test_labels[0]) 
# [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]

One-hot encoding

五個步驟 - 4. 模型訓練

# A_Neural_Network_keras.py

# ...


# epochs:表示訓練遍數
# batch_size:表示每次餵給網路的資料數目
network.fit(train_images, train_labels, epochs=5, batch_size=128)

模型訓練

五個步驟 - 5. 評估測試正確率

# A_Neural_Network_keras.py

# ...

# 檢測在測試集上的正確率
test_loss, test_acc = network.evaluate(test_images, test_labels)
print('測試正確率:', test_acc)

評估測試正確率

五個步驟 - 5. 評估測試正確率

Reference

  • François Chollet, (2019/05/31), Deep learning 深度學習必讀:Keras 大神帶你用 Python 實作.
  • François Chollet, (2019/05/31), A first look at a neural network. Retrieved from https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/2.1-a-first-look-at-a-neural-network.ipynb
  • 周秉誼, (2016/12/20), Deep Learning 開發及常用套件介紹. Retrieved from http://www.cc.ntu.edu.tw/chinese/epaper/0039/20161220_3910.html

Reference

  • CH.Tseng, (2017/09/23), 學習使用Keras建立卷積神經網路. Retrieved from https://chtseng.wordpress.com/2017/09/23/%E5%AD%B8%E7%BF%92%E4%BD%BF%E7%94%A8keras%E5%BB%BA%E7%AB%8B%E5%8D%B7%E7%A9%8D%E7%A5%9E%E7%B6%93%E7%B6%B2%E8%B7%AF/
  • mikechenx, (2017/12/13), Day 03:Neural Network 的概念探討. Retrieved from https://ithelp.ithome.com.tw/users/20001976/ironman
  • 國家實驗研究院, (2017/08/01), TensorFlow 基礎篇〈下〉. Retrieved from http://fgc.stpi.narl.org.tw/activity/videoDetail/4b1141305d9cd231015d9d08fb62002d

Supplement

Thanks for listening.

建構一個類神經網路

By BessyHuang

建構一個類神經網路

  • 532