Tensorflow中使⽤tfrecord⽅式读取数据的⽅法
前⾔
本博客默认读者对神经⽹络与Tensorflow有⼀定了解,对其中的⼀些术语不再做具体解释。并且本博客主要以图⽚数据为例进⾏介绍,如有错误,敬请斧正。
使⽤Tensorflow训练神经⽹络时,我们可以⽤多种⽅式来读取⾃⼰的数据。如果数据集⽐较⼩,⽽且内存⾜够⼤,可以选择直接将所有数据读进内存,然后每次取⼀个batch的数据出来。如果数据较多,可以每次直接从硬盘中进⾏读取,不过这种⽅式的读取效率就⽐较低了。此篇博客就主要讲⼀下Tensorflow官⽅推荐的⼀种较为⾼效的数据读取⽅式——tfrecord。
从宏观来讲,tfrecord其实是⼀种数据存储形式。使⽤tfrecord时,实际上是先读取原⽣数据,然后转换成tfrecord格式,再存储在硬盘上。⽽使⽤时,再把数据从相应的tfrecord⽂件中解码读取出来。那么使⽤tfrecord和直接从硬盘读取原⽣数据相⽐到底有什么优势呢?其实,Tensorflow有和tfrecord配套的⼀些函数,可以加快数据的处理。实际读取tfrecord数据时,先以相应的tfrecord⽂件为参数,创建⼀个输⼊队列,这个队列有⼀定的容量(视具体硬件限制,⽤户可以设置不同的值),在⼀部分数据出队列时,tfrecord中的其他数据就可以通过预取进⼊队列,并且这个过程和⽹络的计算是独⽴进⾏的。也就是说,⽹络每⼀个iteration的训练不必等待数据队列准备好再开始,队列中的数据始终是充⾜的,⽽往队列中填
充数据时,也可以使⽤多线程加速。
下⾯,本⽂将从以下4个⽅⾯对tfrecord进⾏介绍:
1. tfrecord格式简介
2. 利⽤⾃⼰的数据⽣成tfrecord⽂件
3. 从tfrecord⽂件读取数据
4. 实例测试
1. tfrecord格式简介
这部分主要参考了另⼀篇博⽂,
tfecord⽂件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下⾯是tf.train.Example的定义
message Example {
Features features = 1;
};
message Features{
map<string,Feature> featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
从上述代码可以看出,tf.train.Example 的数据结构很简单。tf.train.Example中包含了⼀个从属性名称到取值的字典,其中属性名称为⼀个字符串,属性的取值可以为字符串(BytesList ),浮点数列表(FloatList )或整数列表(Int64List )。例如我们可以将图⽚转换为字符串进⾏存储,图像对应的类别标号作为整数存储,⽽⽤于回归任务的ground-truth可以作为浮点数存储。通过后⾯的代码我们会对tfrecord的这种字典形式有更直观的认识。
2. 利⽤⾃⼰的数据⽣成tfrecord⽂件
先上⼀段代码,然后我再针对代码进⾏相关介绍。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
import scipy.io as sio
def _bytes_feature(value):
ain.Feature(bytes_list = tf.train.BytesList(value=[value]))
def _int64_feature(value):
ain.Feature(int64_list = tf.train.Int64List(value=[value]))
root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecords_filename = root_path + 'tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)
height = 300
width = 300
meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
meanvalue = meanfile['mean']
txtfile = root_path + ''
fr = open(txtfile)
for i adlines():
item = i.split()
img = np.float64(misc.imread(root_path + '/images/train_images/' + item[0]))
img = img - meanvalue
maskmat = sio.loadmat(root_path + '/mats/train_mats/' + item[1])
mask = np.float64(maskmat['seg_mask'])
label = int(item[2])
img_raw = string()
mask_raw = string()
example = tf.train.Example(ain.Features(feature={
'height': _int64_feature(height),
'width': _int64_feature(width),
'name': _bytes_feature(item[0]),
'image_raw': _bytes_feature(img_raw),
'mask_raw': _bytes_feature(mask_raw),
'label': _int64_feature(label)}))
writer.write(example.SerializeToString())
writer.close()
fr.close()
代码中前两个函数(_bytes_feature和_int64_feature)是将我们的原⽣数据进⾏转换⽤的,尤其是图⽚要转换成字符串再进⾏存储。这两个函数的定义来⾃官⽅的⽰例。
接下来,我定义了数据的(路径-label⽂件)txtfile,它⼤概长这个样⼦:
这⾥稍微啰嗦下,介绍⼀下我的实验内容。我做的是⼀个multi-task的实验,⼀⽀task做分割,⼀⽀task做分类。所以txtfile中每⼀⾏是⼀个样本,每个样本⼜包含3项,第⼀项为图⽚名称,第⼆项为相应的ground-truth segmentation mask的名称,第三项是图⽚的标签。(txtfile中内容形式⽆所谓,只要能读到想读的数据就可以)
接着回到主题继续讲代码,之后我⼜定义了即将⽣成的tfrecord的⽂件路径和名称,即tfrecord_filename,还有⼀个writer,这个writer是进⾏写操作⽤的。
接下来是图⽚的⾼度、宽度以及我事先在整个数据集上计算好的图像均值⽂件。⾼度、宽度其实完全没必要引⼊,这⾥只是为了说明tfrecord的⽣成⽽写的。⽽均值⽂件是为了对图像进⾏事先的去均值化操作⽽引⼊的,在⼤多数机器学习任务中,图像去均值化对提⾼算法的性能还是很有帮助的。
最后就是根据txtfile中的每⼀⾏进⾏相关数据的读取、转换以及tfrecord的⽣成了。⾸先是根据图⽚路径读取图⽚内容,然后图像减去之前读⼊的均值,接着根据segmentation mask的路径读取mask(如果只是图像分类任务,那么就不会有这些额外的mask),txtfile中的label读出来是string格式,这⾥要转换成int。然后图像和mask数据也要⽤相应的tosring函数转换成string。
真正的核⼼是下⾯这⼀⼩段代码:
example = tf.train.Example(ain.Features(feature={
'height': _int64_feature(height),
'width': _int64_feature(width),
'name': _bytes_feature(item[0]),
session如何设置和读取'image_raw': _bytes_feature(img_raw),
'mask_raw': _bytes_feature(mask_raw),
'label': _int64_feature(label)}))
writer.write(example.SerializeToString())
这⾥很好地体现了tfrecord的字典特性,tfrecord中每⼀个样本都是⼀个⼩字典,这个字典可以包含任意多个键值对。⽐如我这⾥就存储了图⽚的⾼度、宽度、图⽚名称、图⽚内容、mask内容以及图⽚的label。对于我的任务来说,其实height、width、name都不是必需的,这⾥仅仅是为了展⽰。键值对的键全都是字符串,键起什么名字都可以,只要能⽅便以后使⽤就可以。定义好⼀个example后就可以⽤之前的writer来把它真正写⼊tfrecord⽂件了,这其实就跟把⼀⾏内容写⼊⼀个txt⽂件⼀样。代码的最后就是writer和txt⽂件对象的关闭了。
最后在指定⽂件夹下,就得到了指定名字的tfrecord⽂件,如下所⽰:
需要注意的是,⽣成的tfrecord⽂件⽐原⽣数据的⼤⼩还要⼤,这是正常现象。这种现象可能是因为图⽚⼀般都存储为jpg等压缩格式,⽽tfrecord⽂件存储的是解压后的数据。
3. 从tfrecord⽂件读取数据
还是代码先⾏。
from scipy import misc
import tensorflow as tf
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'
def read_and_decode(filename_queue, random_crop=False, random_clip=False, shuffle_batch=True):
reader = tf.TFRecordReader()
_, serialized_example = ad(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'name': tf.FixedLenFeature([], tf.string),
'image_raw': tf.FixedLenFeature([], tf.string),
'mask_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features['image_raw'], tf.float64)
image = tf.reshape(image, [300,300,3])
mask = tf.decode_raw(features['mask_raw'], tf.float64)
mask = tf.reshape(mask, [300,300])
name = features['name']
label = features['label']
width = features['width']
height = features['height']
#  if random_crop:
#    image = tf.random_crop(image, [227, 227, 3])
#  else:
#    image = size_image_with_crop_or_pad(image, 227, 227)
#  if random_clip:
#    image = tf.image.random_flip_left_right(image)
if shuffle_batch:
images, masks, names, labels, widths, heights = tf.train.shuffle_batch([image, mask, name, label, width, height],
batch_size=4,
capacity=8000,
num_threads=4,
min_after_dequeue=2000)
else:
images, masks, names, labels, widths, heights = tf.train.batch([image, mask, name, label, width, height],
batch_size=4,
capacity=8000,
num_threads=4)
return images, masks, names, labels, widths, heights
读取tfrecord⽂件中的数据主要是应⽤read_and_decode()这个函数,可以看到其中有个参数是filename_queue,其实我们并不是直接从tfrecord⽂件进⾏读取,⽽是要先利⽤tfrecord⽂件创建⼀个输⼊队列,如本⽂开头所述那样。关于这点,到后⾯真正的测试代码我再介绍。
在read_and_decode()中,⼀上来我们先定义⼀个reader对象,然后使⽤reader得到serialized_example,这是⼀个序列化的对象,接着使⽤tf.parse_single_example()函数对此对象进⾏初步解析。从代码中可以看到,解析时,我们要⽤到之前定义的那些键。对于图像、mask这种转换成字符串的数据,要进⼀步使⽤tf.decode_raw()函数进⾏解析,这⾥要特别注意函数⾥的第⼆个参数,也就是解析后的类型。之前图⽚在转成字符串之前是什么类型的数据,那么这⾥的参数就要填成对应的类型,否则会报错。对于name、label、width、height这样的数据就不⽤再解析了,我们得到的features对象就是个字典,利⽤键就可以拿到对应的值,如代码所⽰。
我注释掉的部分是⽤来做数据增强的,⽐如随机的裁剪与翻转,除了这两种,其他形式的数据增强也可以写在这⾥,读者可以根据⾃⼰的需要,决定是否使⽤各种数据增强⽅式。
函数最后就是使⽤解析出来的数据⽣成batch了。Tensorflow提供了两种⽅式,⼀种是shuffle_batch,这种主要是⽤在训练中,随机选取样本组成batch。另外⼀种就是按照数据在tfrecord中的先后顺序⽣成batch。对于⽣成batch的函数,建议读者去官⽹查看API⽂档进⾏细致了解。这⾥稍微做⼀下介绍,batch的⼤⼩,即batch_size就需要在⽣成batch的函数⾥指定。另外,capacity参数指定数据队列⼀次性能放多少个样本,此参数设置什么值需要视硬件环境⽽定。num_threads参数指定可以开启⼏个线程来向数据队列中填充数据,如果硬件性能不够强,最好设⼩⼀点,否则容易崩。
4. 实例测试
实际使⽤时先指定好我们需要使⽤的tfrecord⽂件:
root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'
然后⽤该tfrecord⽂件创建⼀个输⼊队列:
filename_queue = tf.train.string_input_producer([tfrecord_filename],
num_epochs=3)
这⾥有个参数是num_epochs,指定好之后,Tensorflow⾃然知道如何读取数据,保证在遍历数据集的⼀个epoch中样本不会重复,也知道数据读取何时应该停⽌。
下⾯我将完整的测试代码贴出:
def test_run(tfrecord_filename):
filename_queue = tf.train.string_input_producer([tfrecord_filename],
num_epochs=3)
images, masks, names, labels, widths, heights = read_and_decode(filename_queue)
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
meanvalue = meanfile['mean']
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(1):
imgs, msks, nms, labs, wids, heis = sess.run([images, masks, names, labels, widths, heights])
print 'batch' + str(i) + ': '
#print type(imgs[0])
for j in range(4):
print nms[j] + ': ' + str(labs[j]) + ' ' + str(wids[j]) + ' ' + str(heis[j])
img = np.uint8(imgs[j] + meanvalue)
msk = np.uint8(msks[j])
plt.subplot(4,2,j*2+1)
plt.imshow(img)
plt.subplot(4,2,j*2+2)
plt.imshow(msk, vmin=0, vmax=5)
plt.show()
coord.join(threads)
函数中接下来就是利⽤之前定义的read_and_decode()来得到⼀个batch的数据,此后我⼜读⼊了均值⽂件,这是因为之前做了去均值处理,如果要正常显⽰图⽚需要再把均值加回来。
再之后就是建⽴⼀个Tensorflow session,然后初始化对象。这些是Tensorflow基本操作,不再赘述。下⾯的这两句代码⾮常重要,是读取数据必不可少的。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
然后是运⾏sess.run()拿到实际数据,之前只是相当于定义好了,并没有得到真实数值。为了简单起见,我在之后的循环⾥只测试了⼀个batch的数据,关于tfrecord的标准使⽤我也建议读者去官⽹的数据读取部分看看⽰例。循环⾥对数据的各种信息进⾏了展⽰,结果如下:
从图⽚的名字可以看出,数据的确是进⾏了shuffle的,标签、宽度、⾼度、图⽚本⾝以及对应的mask图像也全部展⽰出来了。
测试函数的最后,要使⽤以下两句代码进⾏停⽌,就如同⽂件需要close()⼀样:
以上就是本⽂的全部内容,希望对⼤家的学习有所帮助,也希望⼤家多多⽀持。

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。