图像数据-TFrecord在动态图中的使用

本文介绍图片数据使用TFrecord和tf.data.dataset进行存储和读取。

Tensorflow 提供了四种数据读取方式:

  1. Preloaded data: 用一个tf.constant常量将数据集加载进来,主要用于很小的数据集;

  2. Feeding: 使用python代码供给数据,将所有数据加载进内存,然后一个batch一个batch地输入到计算图中, 适用于小数据集;

  3. QueueRunner: 基于队列的输入通道,读取TFrecord静态图使用;
  4. tf.data API: 能够从不同的输入或文件格式中读取、预处理数据,并且对数据应用一些变换(例如,batching、shuffling、mapping function over the dataset),tf.data API 是旧的 feeding、QueueRunner的升级。值得注意的是, Eager模式必须使用该API来构建输入通道, 一般结合TFrecord使用。该API相比于Queue更容易使用。

What‘s TFrecord

TFrecord是Tensorflow提供的一种二进制存储格式,可将数据和标签统一存储。从上述读取方式中可以看出,TFrecord在QueueRunner和tf.data API读取中均扮演了重要的角色。

Why TFrecord

与其他方案相比, 使用TFrecord读取的优点在于:

  1. 可处理大规模数据量,而不会造成其他方案所带来的内存不够用的问题;
  2. 在Feeding方案中,batch读取的IO操作势必会阻塞训练,前一个batch加载完成后,神经网络必须等待下一个batch加载完成后才能继续训练,效率较低。

How To Use

TFrecord的使用主要有两块:一是图片数据转TFrecord格式存储,二是解析存储好的TFrecord文件。下面逐一介绍。

图片转TFrecord

本文使用的数据集是Kaggle猫狗数据集。

该数据集包含train和test两个文件夹, 分别为训练集和测试集,下面以train集为例操作。

1
2
3
ls |wc -w

25000

训练集包含25000张图片,猫狗各一半。

1
2
3
4
$ ls 

cat.124.jpg cat.3750.jpg cat.6250.jpg cat.8751.jpg dog.11250.jpg dog.2500.jpg dog.5000.jpg dog.7501.jpg
...

图片文件以jpg格式存储,以cat, dog作为文件名开头。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import os
from tqdm import tqdm
import tensorflow as tf


def img_tfrecord_encode(classes, tfrecord_filename, data_path, is_training=True):
"""
功能:读取图片转换成tfrecord格式的文件
@params: classes: 标签类别 @type:classes: dict
@params: tfrecord_filename: tfrecord文件保存文件
@type:tfrecord_filename: str
@params: data_path: 原始训练集存储路径
@is_training: 是否为训练集,用来区分训练集和测试集
"""
# 初始化一个writer
writer = tf.python_io.TFRecordWriter(tfrecord_filename)
for img_name in tqdm(os.listdir(path)):
name = img_name.split('.')[0]
# 使用tf.gfile.FastFile读取图片要比PIL.Image读取处理得到的
# 最终TFrecod文件小得多,在本案例中,IMAGE方式读取得到的TFrecord大小约为3.7G
# 而tf.gfile.FastFile得到的约为548M
with tf.gfile.FastGFile(os.path.join(path, img_name), 'rb') as gf:
img = gf.read()
if is_training:
# 构造特征
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[classes[name]])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),
'file_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_name.encode()]))
}
else:
feature = {
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[-1])),
'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),
'file_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_name.encode()]))
}
# example 对象将label和image特征进行封装
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString()) # 序列化为字符串
writer.close()
print('tfrecord writen done!')

调用上述函数,可得到猫狗训练集的TFrecord格式文件

1
2
3
4
5
if __name__ == '__main__':
classes = {'cat': 0, 'dog': 1}
tfrecord_filename = 'cat_and_dog.tfrecord'
data_path = 'train/'
img_tfrecord_encode(classes, tfrecord_filename, data_path, is_training=True)

上述程序运行大约需要2min。

使用tf.data读取TFrecord

在动态图(eager)模式下,QueueRunner不可用,必须使用tf.data进行TFrecord的读取。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

def img_tfrecord_parse(tfrecord_filename, epochs, batch_size, shape,
padded_shapes=None, shuffle=True, buffer_size=1000):
"""
@param: tfrecord_filename:tfrecord文件列表 @type:list
@param: epoch:训练轮数(repeating次数) @type:int
@param:batch_size:批数据大小 @type:int
@param: shape:图片维度 @type:tuple
@param: padded_shapes:不定长padding @type:tuple
@param: shuffle:是否打乱 @type:boolean
"""
# 解析单个example,特征与encode一一对应。
def parse_example(serialized_example):
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string),
'file_name': tf.FixedLenFeature([], tf.string)
})
# 解码
image = tf.image.decode_jpeg(features['img_raw'])
# 设置shape
image = tf.image.resize_images(image, shape, method=1)
label = tf.cast(features['label'], tf.int64)
file_name = tf.cast(features['file_name'], tf.string)
return image, label, file_name

# 解析TFrecord
dataset = tf.data.TFRecordDataset(tfrecord_filename).map(parse_example)
if shuffle:
if padded_shapes:
dataset = dataset.repeat(epochs).shuffle(buffer_size=buffer_size).padded_batch(batch_size, padded_shapes)
else:
dataset = dataset.repeat(epochs).shuffle(buffer_size=buffer_size).batch(batch_size)
else:
if padded_shapes:
dataset = dataset.repeat(epochs).padded_batch(batch_size, padded_shapes)
else:
dataset = dataset.repeat(epochs).batch(batch_size)
return dataset

调用上述函数,解析TFrecord得到dataset。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
if __name__ == '__main__'():
tfrecord_filename = 'cat_and_dog.tfrecord'
epochs = 100
batch_size = 64
shape = (227, 227)
dataset = img_tfrecord_parse(tfrecord_filename=tfrecord_filename,
epochs=epochs,
batch_size=batch_size,
shape=shape)
# 查看dataset
iterator = dataset.make_one_hot_iterator()
image, label, file_name = iterator.get_next()
print(image[0])
print(label[0])
print(file_name[0])