Cifar-10 的所有圖片被分為 10 個類別 (以 0~9 數字作為 Label 之編碼) :
- 0 : airplain (飛機)
- 1 : automobile (汽車)
- 2 : bird (鳥)
- 3 : cat (貓)
- 4 : deer (鹿)
- 5 : dog (狗)
- 6 : frog (青蛙)
- 7 : horse (馬)
- 8 : ship (船)
- 9 : truck (卡車)
# https://en.wikipedia.org/wiki/CIFAR-10
# https://www.cs.toronto.edu/~kriz/cifar.html (可下載 Cifar-10 資料集)
Cifar-10 名稱來自加拿大高等研究院 (Canadian Institute For Advanced Research), 10 表示其包含 10 種類別圖片. Cifar-10 事實上是一個包含 8000 萬個已標記 (Labeled) 圖庫的子集合, 它還有一個更大的姊妹 Cifar-100 資料集, 同樣包含 60000 個圖片, 但有100 種類別.
Keras 有提供處理 Cifar-10 資料集之模組 cifar10, 可利用 Keras 建構機器學習模型, 利用 5 萬筆訓練集圖片訓練模型中之參數, 然後用訓練好的模型來預測 1 萬筆測試集中的圖片屬於 10 種類別中的哪一種. 以下是依據林大貴的 "TensorFlow+Keras 深度學習人工智慧實務應用" 第 9 章進行測試並記錄結果.
本系列之前的測試紀錄參考 :
# Windows 安裝深度學習框架 TensorFlow 與 Keras
# 使用 Keras 測試 MNIST 手寫數字辨識資料集
# 使用 Keras 多層感知器 MLP 辨識手寫數字 (一)
# 使用 Keras 多層感知器 MLP 辨識手寫數字 (二)
# 使用 Keras 卷積神經網路 (CNN) 辨識手寫數字
1. 匯入 Keras 的 cifar10 模組
D:\Python\test>python
Python 3.6.1 (v3.6.1:69c0db5, Mar 21 2017, 18:41:36) [MSC v.1900 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from keras.datasets import cifar10
Using TensorFlow backend.
>>> import numpy as np
>>> np.random.seed(10)
2. 載入 Cifar-10 資料集
呼叫 cifar10.load_data() 即自動從 Alex Krixhevsky 在多倫多大學的 Cifar-10 網站下載資料集檔案 cifar-10-python.tar.gz 至 C:\使用者目錄的 .keras 子目錄下, 並自動將資料集解壓縮至 cifar-10-batches-py 子目錄下 :
>>> (x_train_image, y_train_label), (x_test_image, y_test_label)=cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 371s 2us/step
在呼叫 load_data() 後觀察 Windows C:\使用者目錄 (此處為 Tony) 下的 .keras 目子目錄, 可以發現下載完成後會在 .keras 下自動產生一個子目錄 cifar-10-batches-py 來存放解出來的資料集:
cifar10.load_data() 的傳回值為訓練集/測試集數字圖片陣列與其標籤陣列所組成之 tuple, 利用陣列的方法就可以取用 Cifar-10 資料集中的圖片了.
與 mnist.load_data() 不同的是, cifar10.load_data() 會固定去 Cifar-10 網頁下載資料集, 如果因為防火牆的阻擋而無法下載資料集的話, 就會出現 "連線嘗試失敗" 的錯誤訊息. 即使在 Cifar-10 網頁下載資料集的壓縮檔 (有 Python, Matlab, 以及 C 語言用的二進檔等三種版本, 我下載的是 Python 版的 cifar-10-python.tar.gz, 大約 163 MB) 自行解壓縮到上述之 .keras 目錄, 呼叫 load_data() 不會去檢查 .keras 目錄下是否已有 Cifar-10 資料集, 因此還是出現連線錯誤訊息 :
>>> (x_train_image, y_train_label), (x_test_image, y_test_label)=cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Traceback (most recent call last):
File "C:\Python36\lib\urllib\request.py", line 1318, in do_open
encode_chunked=req.has_header('Transfer-encoding'))
File "C:\Python36\lib\http\client.py", line 1239, in request
self._send_request(method, url, body, headers, encode_chunked)
File "C:\Python36\lib\http\client.py", line 1285, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "C:\Python36\lib\http\client.py", line 1234, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "C:\Python36\lib\http\client.py", line 1026, in _send_output
self.send(msg)
File "C:\Python36\lib\http\client.py", line 964, in send
self.connect()
File "C:\Python36\lib\http\client.py", line 1392, in connect
super().connect()
File "C:\Python36\lib\http\client.py", line 936, in connect
(self.host,self.port), self.timeout, self.source_address)
File "C:\Python36\lib\socket.py", line 722, in create_connection
raise err
File "C:\Python36\lib\socket.py", line 713, in create_connection
sock.connect(sa)
TimeoutError: [WinError 10060] 連線嘗試失敗,因為連線對象有一段時間並未正確回應,或是連線建立失敗,因為連線的主機無法回應。
3. 查詢資料集
利用陣列的 len() 函數可以查詢陣列長度 :
>>> print('train image numbers=', len(x_train_image)) #顯示訓練圖片筆數 : 5 萬筆
train image numbers= 50000
>>> print('train label numbers=', len(y_train_label)) #顯示訓練標籤筆數 : 5 萬筆
train label numbers= 50000
>>> print('test image numbers=', len(x_test_image)) #顯示訓練標籤筆數 : 1 萬筆
test image numbers= 10000
>>> print('test label numbers=', len(y_test_label)) #顯示測試標籤筆數 : 1 萬筆
test label numbers= 10000
利用陣列的 shape 屬性可查詢文字圖片的外型屬性, 包含筆數, 解析度, 以及色版數目 :
>>> print('x_train_image:', x_train_image.shape) #顯示訓練集圖片之 shape
x_train_image: (50000, 32, 32, 3) #訓練集為 5 萬筆之 32*32 RGB 彩色圖片
>>> print('y_train_label:', y_train_label.shape) #顯示訓練集標籤之 shape
y_train_label: (50000, 1) #訓練集標籤為 5 萬筆 0~9 數字
>>> print('x_test_image:', x_test_image.shape) #顯示測試集圖片之 shape
x_test_image: (10000, 32, 32, 3) #測試集為 1 萬筆之 32*32 RGB 彩色圖片
>>> print('y_test_label:', y_test_label.shape) #顯示測試集標籤之 shape
y_test_label: (10000, 1) #測試集標籤為 1 萬筆 0~9 數字
可見訓練集之數字圖片陣列共有 50000 筆 32*32 解析度的彩色 RGB 圖片 (色版=3); 而測試集則有 10000 筆 32*32 解析度的彩色 RGB 圖片, 不論是訓練集還是測試集, 傳回值 x_train_image 與 x_test_image 均為 3 維陣列, 利用索引 0~49999 可以查詢 5 萬筆訓練集圖片之內容, 測試集圖片的索引範圍則為 0~9999. 例如訓練集的第一張圖片資料為陣列 x_train_image[0] :
>>> print('x_train_image[0]=', x_train_image[0])
x_train_image[0]= [
[[ 59 62 63] #第 1 列畫素開始 (0)
[ 43 46 45]
[ 50 48 43]
...
[158 132 108]
[152 125 102]
[148 124 103]] #第 1 列畫素結束 (31)
[[ 16 20 20] #第 2 列畫素開始 (0)
[ 0 0 0]
[ 18 8 0]
...
[123 88 55]
[119 83 50]
[122 87 57]] #第 2 列畫素結束 (31)
[[ 25 24 21]
[ 16 7 0]
[ 49 27 8]
...
[118 84 50]
[120 84 50]
[109 73 42]]
...
[[208 170 96]
[201 153 34]
[198 161 26]
...
[160 133 70]
[ 56 31 7]
[ 53 34 20]]
[[180 139 96]
[173 123 42]
[186 144 30]
...
[184 148 94]
[ 97 62 34]
[ 83 53 34]]
[[177 144 116] #第 32 列畫素開始 (0)
[168 129 94]
[179 142 87]
...
[216 184 140]
[151 118 84]
[123 92 72]]] #第 32 列畫素結束 (0)
由於資料太長, 所以輸出被自動節略了. 不過從輸出可知, 每一個畫素以 1*3 向量 [R G B] 形式來表示, 第一維表示列, 後兩維表示每一列畫素, 因此一張圖總共有 32*32=1024 個 1 維向量, 有 32*32*3=3072 個數字. 訓練集第一張圖的第一列畫素為 x_train_image[0][0], 由 32 個一維向量組成 :
>>> print('x_train_image[0][0]=', x_train_image[0][0])
x_train_image[0][0]= [[ 59 62 63]
[ 43 46 45]
[ 50 48 43]
[ 68 54 42]
[ 98 73 52]
[119 91 63]
[139 107 75]
[145 110 80]
[149 117 89]
[149 120 93]
[131 103 77]
[125 99 76]
[142 115 91]
[144 112 86]
[137 105 79]
[129 97 71]
[137 106 79]
[134 106 76]
[124 97 64]
[139 113 78]
[139 112 75]
[133 105 69]
[136 105 74]
[139 108 77]
[152 120 89]
[163 131 100]
[168 136 108]
[159 129 102]
[158 130 104]
[158 132 108]
[152 125 102]
[148 124 103]]
訓練集第一張圖片的第一個畫素 x_train_image[0][0][0] 如下 :
>>> print('x_train_image[0][0][0]=', x_train_image[0][0][0])
x_train_image[0][0][0]= [59 62 63]
4. 顯示訓練集圖片
顯示 Cifar-10 資料集中的圖片可利用 matplotlib.pyplot 模組的 imshow() 函數, 參考之前 MNIST 測試中以 imshow() 為主的自訂繪圖函數 plot_image() :
# 使用 Keras 測試 MNIST 手寫數字辨識資料集 (步驟 6)
先自訂 plot_image() 函數再呼叫它來顯示訓練集的第一張圖片與其標籤 :
>>> import matplotlib.pyplot as plt
>>> def plot_image(image):
... fig=plt.gcf()
... fig.set_size_inches(2, 2)
... plt.imshow(image, cmap='binary')
... plt.show()
...
>>> print(y_train_label[0]) #第一張圖片為類別 6 之青蛙
[6]
>>> plot_image(x_train_image[0]) #顯示第一張圖片
將此圖放大後可清楚看出此青蛙圖片是由 32*32 的畫素組成, 每一個畫素之顏色即由上述之一維向量中的 RGB 三原色所決定 :
我將上面的顯示圖片指令寫成如下可在命令列執行之程式檔 :
#show_cifar10_train_image0.py
from keras.datasets import cifar10
import matplotlib.pyplot as plt
def plot_image(image): #自訂繪圖函數
fig=plt.gcf() #取得 pyplot 物件參考
fig.set_size_inches(2, 2) #設定畫布大小為 2 吋*2吋
plt.imshow(image, cmap='binary') #以 binary (灰階) 顯示 28*28 圖形
plt.show() #顯示圖形
(x_train_image, y_train_label), \
(x_test_image, y_test_label)=cifar10.load_data() #載入 MNIST 資料集
print(y_train_label[0]) #顯示第一筆樣本之標籤 (label)
plot_image(x_train_image[0]) #繪製第一筆樣本之圖形
執行結果如上圖 :
D:\Python\test>python show_cifar10_train_image0.py
Using TensorFlow backend.
[6]
可同時顯示多張 Cifar-10 圖片的程式改編如下 :
#show_cifar10_train_images.py
import sys
from keras.datasets import cifar10
import matplotlib.pyplot as plt
label_dict={0:"airplain",1:"automobile",2:"bird",3:"cat",4:"deer",5:"dog",
6:"frog",7:"horse",8:"ship",9:"truck"} #轉換標籤為類別名稱用
def plot_images_labels_prediction(images,labels,prediction,idx,num=10):
fig=plt.gcf() #取得 pyplot 物件參考
fig.set_size_inches(12, 14) #設定畫布大小為 12 吋*14吋
if num > 25: num=25 #限制最多顯示 25 個子圖
for i in range(0, num): #依序顯示 num 個子圖
ax=plt.subplot(5, 5, i+1) #建立 5*5 個子圖中的第 i+1 個
ax.imshow(images[idx], cmap='binary') #顯示子圖
title=str(idx) + "." + label_dict[labels[idx][0]] + str(labels[idx])
if len(prediction) > 0: #有預測值就加入標題中
title += ",predict=" + str(prediction[idx])
ax.set_title(title, fontsize=10) #設定標題
ax.set_xticks([]); #不顯示 x 軸刻度
ax.set_yticks([]); #不顯示 y 軸刻度
idx += 1 #樣本序號增量 1
plt.show() #繪製圖形
(x_train_image, y_train_label), \
(x_test_image, y_test_label)=cifar10.load_data() #載入 Cifar-10 資料集
i=int(sys.argv[1]) #取得第一個命令列參數 ()
j=int(sys.argv[2]) #取得第二個命令列參數
plot_images_labels_prediction(x_train_image,y_train_label,[],i,j) #無預測值
與之前 MNIST 時不同之處是 title 顯示的部分, 串接了圖片索引, 類別名稱, 以及標籤, 其中類別名稱是利用 label_dict 字典將標籤轉成名稱.
以下是顯示訓練集第一張開始的 25 張圖片 :;
D:\Python\test>python show_cifar10_train_images.py 0 25
Using TensorFlow backend.
訓練集最後 25 張圖片開始索引為 49975 :
D:\Python\test>python show_cifar10_train_images.py 49975 25
Using TensorFlow backend.
Cifar-10 的彩色圖片比單調的 MNIST 要賞心悅目多了.
1. 匯入 Keras 的 cifar10 模組
D:\Python\test>python
Python 3.6.1 (v3.6.1:69c0db5, Mar 21 2017, 18:41:36) [MSC v.1900 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from keras.datasets import cifar10
Using TensorFlow backend.
>>> import numpy as np
>>> np.random.seed(10)
2. 載入 Cifar-10 資料集
呼叫 cifar10.load_data() 即自動從 Alex Krixhevsky 在多倫多大學的 Cifar-10 網站下載資料集檔案 cifar-10-python.tar.gz 至 C:\使用者目錄的 .keras 子目錄下, 並自動將資料集解壓縮至 cifar-10-batches-py 子目錄下 :
>>> (x_train_image, y_train_label), (x_test_image, y_test_label)=cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 371s 2us/step
在呼叫 load_data() 後觀察 Windows C:\使用者目錄 (此處為 Tony) 下的 .keras 目子目錄, 可以發現下載完成後會在 .keras 下自動產生一個子目錄 cifar-10-batches-py 來存放解出來的資料集:
切換至 cifar-10-batches-py 子目錄可見有 8 個檔案, 其中 data_batch_1~5 為總數 5 萬筆之訓練集 (每一個檔案各 1 萬筆), 而 test_batch 為 1 萬筆之測試集 :
cifar10.load_data() 的傳回值為訓練集/測試集數字圖片陣列與其標籤陣列所組成之 tuple, 利用陣列的方法就可以取用 Cifar-10 資料集中的圖片了.
與 mnist.load_data() 不同的是, cifar10.load_data() 會固定去 Cifar-10 網頁下載資料集, 如果因為防火牆的阻擋而無法下載資料集的話, 就會出現 "連線嘗試失敗" 的錯誤訊息. 即使在 Cifar-10 網頁下載資料集的壓縮檔 (有 Python, Matlab, 以及 C 語言用的二進檔等三種版本, 我下載的是 Python 版的 cifar-10-python.tar.gz, 大約 163 MB) 自行解壓縮到上述之 .keras 目錄, 呼叫 load_data() 不會去檢查 .keras 目錄下是否已有 Cifar-10 資料集, 因此還是出現連線錯誤訊息 :
>>> (x_train_image, y_train_label), (x_test_image, y_test_label)=cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Traceback (most recent call last):
File "C:\Python36\lib\urllib\request.py", line 1318, in do_open
encode_chunked=req.has_header('Transfer-encoding'))
File "C:\Python36\lib\http\client.py", line 1239, in request
self._send_request(method, url, body, headers, encode_chunked)
File "C:\Python36\lib\http\client.py", line 1285, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "C:\Python36\lib\http\client.py", line 1234, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "C:\Python36\lib\http\client.py", line 1026, in _send_output
self.send(msg)
File "C:\Python36\lib\http\client.py", line 964, in send
self.connect()
File "C:\Python36\lib\http\client.py", line 1392, in connect
super().connect()
File "C:\Python36\lib\http\client.py", line 936, in connect
(self.host,self.port), self.timeout, self.source_address)
File "C:\Python36\lib\socket.py", line 722, in create_connection
raise err
File "C:\Python36\lib\socket.py", line 713, in create_connection
sock.connect(sa)
TimeoutError: [WinError 10060] 連線嘗試失敗,因為連線對象有一段時間並未正確回應,或是連線建立失敗,因為連線的主機無法回應。
3. 查詢資料集
利用陣列的 len() 函數可以查詢陣列長度 :
>>> print('train image numbers=', len(x_train_image)) #顯示訓練圖片筆數 : 5 萬筆
train image numbers= 50000
>>> print('train label numbers=', len(y_train_label)) #顯示訓練標籤筆數 : 5 萬筆
train label numbers= 50000
>>> print('test image numbers=', len(x_test_image)) #顯示訓練標籤筆數 : 1 萬筆
test image numbers= 10000
>>> print('test label numbers=', len(y_test_label)) #顯示測試標籤筆數 : 1 萬筆
test label numbers= 10000
利用陣列的 shape 屬性可查詢文字圖片的外型屬性, 包含筆數, 解析度, 以及色版數目 :
>>> print('x_train_image:', x_train_image.shape) #顯示訓練集圖片之 shape
x_train_image: (50000, 32, 32, 3) #訓練集為 5 萬筆之 32*32 RGB 彩色圖片
>>> print('y_train_label:', y_train_label.shape) #顯示訓練集標籤之 shape
y_train_label: (50000, 1) #訓練集標籤為 5 萬筆 0~9 數字
>>> print('x_test_image:', x_test_image.shape) #顯示測試集圖片之 shape
x_test_image: (10000, 32, 32, 3) #測試集為 1 萬筆之 32*32 RGB 彩色圖片
>>> print('y_test_label:', y_test_label.shape) #顯示測試集標籤之 shape
y_test_label: (10000, 1) #測試集標籤為 1 萬筆 0~9 數字
可見訓練集之數字圖片陣列共有 50000 筆 32*32 解析度的彩色 RGB 圖片 (色版=3); 而測試集則有 10000 筆 32*32 解析度的彩色 RGB 圖片, 不論是訓練集還是測試集, 傳回值 x_train_image 與 x_test_image 均為 3 維陣列, 利用索引 0~49999 可以查詢 5 萬筆訓練集圖片之內容, 測試集圖片的索引範圍則為 0~9999. 例如訓練集的第一張圖片資料為陣列 x_train_image[0] :
>>> print('x_train_image[0]=', x_train_image[0])
x_train_image[0]= [
[[ 59 62 63] #第 1 列畫素開始 (0)
[ 43 46 45]
[ 50 48 43]
...
[158 132 108]
[152 125 102]
[148 124 103]] #第 1 列畫素結束 (31)
[ 0 0 0]
[ 18 8 0]
...
[123 88 55]
[119 83 50]
[122 87 57]] #第 2 列畫素結束 (31)
[[ 25 24 21]
[ 16 7 0]
[ 49 27 8]
...
[118 84 50]
[120 84 50]
[109 73 42]]
...
[[208 170 96]
[201 153 34]
[198 161 26]
...
[160 133 70]
[ 56 31 7]
[ 53 34 20]]
[[180 139 96]
[173 123 42]
[186 144 30]
...
[184 148 94]
[ 97 62 34]
[ 83 53 34]]
[[177 144 116] #第 32 列畫素開始 (0)
[168 129 94]
[179 142 87]
...
[216 184 140]
[151 118 84]
[123 92 72]]] #第 32 列畫素結束 (0)
由於資料太長, 所以輸出被自動節略了. 不過從輸出可知, 每一個畫素以 1*3 向量 [R G B] 形式來表示, 第一維表示列, 後兩維表示每一列畫素, 因此一張圖總共有 32*32=1024 個 1 維向量, 有 32*32*3=3072 個數字. 訓練集第一張圖的第一列畫素為 x_train_image[0][0], 由 32 個一維向量組成 :
>>> print('x_train_image[0][0]=', x_train_image[0][0])
x_train_image[0][0]= [[ 59 62 63]
[ 43 46 45]
[ 50 48 43]
[ 68 54 42]
[ 98 73 52]
[119 91 63]
[139 107 75]
[145 110 80]
[149 117 89]
[149 120 93]
[131 103 77]
[125 99 76]
[142 115 91]
[144 112 86]
[137 105 79]
[129 97 71]
[137 106 79]
[134 106 76]
[124 97 64]
[139 113 78]
[139 112 75]
[133 105 69]
[136 105 74]
[139 108 77]
[152 120 89]
[163 131 100]
[168 136 108]
[159 129 102]
[158 130 104]
[158 132 108]
[152 125 102]
[148 124 103]]
訓練集第一張圖片的第一個畫素 x_train_image[0][0][0] 如下 :
>>> print('x_train_image[0][0][0]=', x_train_image[0][0][0])
x_train_image[0][0][0]= [59 62 63]
向量中的 3 個數值分別為 RGB 顏色之色碼 (0~255), 一張圖片由三個色版相疊而成, 訓練集第一張圖的數字圖片結構之示意圖如下 :
標籤 (Label) 是二維陣列結構, 訓練集之標籤如下 (共 5 萬筆), 可以用單索引或雙索引擷取 :
>>> print('y_train_label=', y_train_label) #顯示全部訓練集標籤
y_train_label= [[6]
[9]
[9]
...
[9]
[1]
[1]]
>>> print('y_train_label[0][0]=', y_train_label[0][0]) #第一張圖片是分類 6 (青蛙)
y_train_label[0][0]= 6
>>> print('y_train_label[49999][0]=', y_train_label[49999][0]) #雙索引
y_train_label[49999][0]= 1
>>> print('y_train_label[49999]=', y_train_label[49999]) #單索引
y_train_label[49999]= [1]
4. 顯示訓練集圖片
顯示 Cifar-10 資料集中的圖片可利用 matplotlib.pyplot 模組的 imshow() 函數, 參考之前 MNIST 測試中以 imshow() 為主的自訂繪圖函數 plot_image() :
# 使用 Keras 測試 MNIST 手寫數字辨識資料集 (步驟 6)
先自訂 plot_image() 函數再呼叫它來顯示訓練集的第一張圖片與其標籤 :
>>> import matplotlib.pyplot as plt
>>> def plot_image(image):
... fig=plt.gcf()
... fig.set_size_inches(2, 2)
... plt.imshow(image, cmap='binary')
... plt.show()
...
>>> print(y_train_label[0]) #第一張圖片為類別 6 之青蛙
[6]
>>> plot_image(x_train_image[0]) #顯示第一張圖片
將此圖放大後可清楚看出此青蛙圖片是由 32*32 的畫素組成, 每一個畫素之顏色即由上述之一維向量中的 RGB 三原色所決定 :
#show_cifar10_train_image0.py
from keras.datasets import cifar10
import matplotlib.pyplot as plt
def plot_image(image): #自訂繪圖函數
fig=plt.gcf() #取得 pyplot 物件參考
fig.set_size_inches(2, 2) #設定畫布大小為 2 吋*2吋
plt.imshow(image, cmap='binary') #以 binary (灰階) 顯示 28*28 圖形
plt.show() #顯示圖形
(x_train_image, y_train_label), \
(x_test_image, y_test_label)=cifar10.load_data() #載入 MNIST 資料集
print(y_train_label[0]) #顯示第一筆樣本之標籤 (label)
plot_image(x_train_image[0]) #繪製第一筆樣本之圖形
執行結果如上圖 :
D:\Python\test>python show_cifar10_train_image0.py
Using TensorFlow backend.
[6]
可同時顯示多張 Cifar-10 圖片的程式改編如下 :
#show_cifar10_train_images.py
import sys
from keras.datasets import cifar10
import matplotlib.pyplot as plt
label_dict={0:"airplain",1:"automobile",2:"bird",3:"cat",4:"deer",5:"dog",
6:"frog",7:"horse",8:"ship",9:"truck"} #轉換標籤為類別名稱用
def plot_images_labels_prediction(images,labels,prediction,idx,num=10):
fig=plt.gcf() #取得 pyplot 物件參考
fig.set_size_inches(12, 14) #設定畫布大小為 12 吋*14吋
if num > 25: num=25 #限制最多顯示 25 個子圖
for i in range(0, num): #依序顯示 num 個子圖
ax=plt.subplot(5, 5, i+1) #建立 5*5 個子圖中的第 i+1 個
ax.imshow(images[idx], cmap='binary') #顯示子圖
title=str(idx) + "." + label_dict[labels[idx][0]] + str(labels[idx])
if len(prediction) > 0: #有預測值就加入標題中
title += ",predict=" + str(prediction[idx])
ax.set_title(title, fontsize=10) #設定標題
ax.set_xticks([]); #不顯示 x 軸刻度
ax.set_yticks([]); #不顯示 y 軸刻度
idx += 1 #樣本序號增量 1
plt.show() #繪製圖形
(x_train_image, y_train_label), \
(x_test_image, y_test_label)=cifar10.load_data() #載入 Cifar-10 資料集
i=int(sys.argv[1]) #取得第一個命令列參數 ()
j=int(sys.argv[2]) #取得第二個命令列參數
plot_images_labels_prediction(x_train_image,y_train_label,[],i,j) #無預測值
與之前 MNIST 時不同之處是 title 顯示的部分, 串接了圖片索引, 類別名稱, 以及標籤, 其中類別名稱是利用 label_dict 字典將標籤轉成名稱.
以下是顯示訓練集第一張開始的 25 張圖片 :;
D:\Python\test>python show_cifar10_train_images.py 0 25
Using TensorFlow backend.
訓練集最後 25 張圖片開始索引為 49975 :
D:\Python\test>python show_cifar10_train_images.py 49975 25
Using TensorFlow backend.
Cifar-10 的彩色圖片比單調的 MNIST 要賞心悅目多了.
沒有留言:
張貼留言