製作自己的數據集(使用tfrecords)
為什麼採用這個格式?
TFRecords文件格式在圖像識別中有很好的使用,其可以將二進制數據和標籤數據(訓練的類別標籤)數據存儲在同一個文件中,它可以在模型進行訓練之前通過預處理步驟將圖像轉換為TFRecords格式,此格式最大的優點實踐每幅輸入圖像和與之關聯的標籤放在同一個文件中.TFRecords文件是一種二進制文件,其不對數據進行壓縮,所以可以被快速加載到內存中.格式不支持隨機訪問,因此它適合於大量的數據流,但不適用於快速分片或其他非連續存取。
前戲:
tf.train.Feature
tf.train.Feature有三個屬性為tf.train.bytes_list tf.train.float_list tf.train.int64_list,顯然我們只需要根據上一步得到的值來設置tf.train.Feature的屬性就可以了,如下所示:
1 tf.train.Feature(int64_list=data_id)
2 tf.train.Feature(bytes_list=data)
tf.train.Features
從名字來看,我們應該能猜出tf.train.Features是tf.train.Feature的複數,事實上tf.train.Features有屬性為feature,這個屬性的一般設置方法是傳入一個字典,字典的key是字符串(feature名),而值是tf.train.Feature對象。因此,我們可以這樣得到tf.train.Features對象:
1 feature_dict = {
2 "data_id": tf.train.Feature(int64_list=data_id),
3 "data": tf.train.Feature(bytes_list=data)
4 }
5 features = tf.train.Features(feature=feature_dict)
tf.train.Example
終於到我們的主角了。tf.train.Example有一個屬性為features,我們只需要將上一步得到的結果再次當做參數傳進來即可。
另外,tf.train.Example還有一個方法SerializeToString()需要説一下,這個方法的作用是把tf.train.Example對象序列化為字符串,因為我們寫入文件的時候不能直接處理對象,需要將其轉化為字符串才能處理。
當然,既然有對象序列化為字符串的方法,那麼肯定有從字符串反序列化到對象的方法,該方法是FromString(),需要傳遞一個tf.train.Example對象序列化後的字符串進去做為參數才能得到反序列化的對象。
在我們這裏,只需要構建tf.train.Example對象並序列化就可以了,這一步的代碼為:
1 example = tf.train.Example(features=features)
2 example_str = example.SerializeToString()
實例(高潮部分):
首先看一下我們的文件夾路徑:
create_tfrecords.py中寫我們的函數
生成數據文件階段代碼如下:
1 def creat_tf(imgpath):
2 cwd = os.getcwd() #獲取當前路徑
3 classes = os.listdir(cwd + imgpath) #獲取到[1, 2]文件夾
4 # 此處定義tfrecords文件存放
5 writer = tf.python_io.TFRecordWriter("train.tfrecords")
6 for index, name in enumerate(classes): #循環獲取倆文件夾(倆類別)
7 class_path = cwd + imgpath + name + "/"
8 if os.path.isdir(class_path):
9 for img_name in os.listdir(class_path):
10 img_path = class_path + img_name
11 img = Image.open(img_path)
12 img = img.resize((224, 224))
13 img_raw = img.tobytes()
14 example = tf.train.Example(features=tf.train.Features(feature={
15 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(name)])),
16 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
17 }))
18 writer.write(example.SerializeToString())
19 print(img_name)
20 writer.close()
這段代碼主要生成 train.tfrecords 文件。
讀取數據階段代碼如下:
1 def read_and_decode(filename):
2 # 根據文件名生成一個隊列
3 filename_queue = tf.train.string_input_producer([filename])
4
5 reader = tf.TFRecordReader()
6 _, serialized_example = reader.read(filename_queue) # 返回文件名和文件
7 features = tf.parse_single_example(serialized_example,
8 features={
9 'label': tf.FixedLenFeature([], tf.int64),
10 'img_raw': tf.FixedLenFeature([], tf.string),
11 })
12
13 img = tf.decode_raw(features['img_raw'], tf.uint8)
14 img = tf.reshape(img, [224, 224, 3])
15 # 轉換為float32類型,並做歸一化處理
16 img = tf.cast(img, tf.float32) # * (1. / 255)
17 label = tf.cast(features['label'], tf.int64)
18 return img, label
訓練階段我們獲取數據的代碼:
1 images, labels = read_and_decode('./train.tfrecords')
2 img_batch, label_batch = tf.train.shuffle_batch([images, labels],
3 batch_size=5,
4 capacity=392,
5 min_after_dequeue=200)
6 init = tf.global_variables_initializer()
7 with tf.Session() as sess:
8 sess.run(init)
9 coord = tf.train.Coordinator() #線程協調器
10 threads = tf.train.start_queue_runners(sess=sess,coord=coord)
11 # 訓練部分代碼--------------------------------
12 IMG, LAB = sess.run([img_batch, label_batch])
13 print(IMG.shape)
14
15 #----------------------------------------------
16 coord.request_stop() # 協調器coord發出所有線程終止信號
17 coord.join(threads) #把開啓的線程加入主線程,等待threads結束
總結(流程):
- 生成tfrecord文件
- 定義
record reader解析tfrecord文件 - 構造一個批生成器(
batcher) - 構建其他的操作
- 初始化所有的操作
- 啓動
QueueRunner
備註:關於tf.train.Coordinator 詳見:
TensorFlow的Session對象是支持多線程的,可以在同一個會話(Session)中創建多個線程,並行執行。在Session中的所有線程都必須能被同步終止,異常必須能被正確捕獲並報告,會話終止的時候, 隊列必須能被正確地關閉。
- 調用 tf.train.slice_input_producer,從 本地文件裏抽取tensor,準備放入Filename Queue(文件名隊列)中;
- 調用 tf.train.batch,從文件名隊列中提取tensor,使用單個或多個線程,準備放入文件隊列;
- 調用 tf.train.Coordinator() 來創建一個線程協調器,用來管理之後在Session中啓動的所有線程;
- 調用tf.train.start_queue_runners, 啓動入隊線程,由多個或單個線程,按照設定規則,把文件讀入Filename Queue中。函數返回線程ID的列表,一般情況下,系統有多少個核,就會啓動多少個入隊線程(入隊具體使用多少個線程在tf.train.batch中定義);
- 文件從 Filename Queue中讀入內存隊列的操作不用手動執行,由tf自動完成;
- 調用sess.run 來啓動數據出列和執行計算;
- 使用 coord.should_stop()來查詢是否應該終止所有線程,當文件隊列(queue)中的所有文件都已經讀取出列的時候,會拋出一個 OutofRangeError 的異常,這時候就應該停止Sesson中的所有線程了;
- 使用coord.request_stop()來發出終止所有線程的命令,使用coord.join(threads)把線程加入主線程,等待threads結束。