TensorFlow⼊门教程(8)读取数据集之Dataset
1、概述
前两讲,我们讲了队列和TFRecord,不知道你们有没有注意到,程序运⾏时,有如下警告(我现在⽤的TensorFlow版本是1.15.1,⽼的版本没有这个警告),
WARNING:tensorflow:From demo4.py:54: string_input_producer (from aining.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use
`tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)
[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
这说明我们以前使⽤队列的⽅式已经淘汰了,它推荐我们使⽤tf.data.Dataset的接⼝,tf.data.Dataset是⽐较⾼级的接⼝,使⽤这个接⼝使得处理数据集更简单,这⼀讲,我们就来看这个tf.data.Dataset接⼝怎么使⽤,这也是TensorFlow现在主推的数据集处理⽅式,必须要重点掌握。
环境配置:
操作系统:Win10 64位
显卡:GTX 1080ti
Python:Python3.7
TensorFlow:1.15.0
2、Dataset对象
tf.data.Dataset接⼝是通过创建Dataset对象来⽣成Dataset数据集的,有了Dataset对象,就可以直接做洗牌(shuffle)、设置batch size、复制数据(repeat)等操作。有三种⽅法可以创建Dataset对象,分别
是tf.data.Dataset.from_tensors、tf.data.Dataset.from_tensor_slices和tf.data.Dataset.from_generator。
我们这⾥主要学习tf.data.Dataset.from_tensor_slices的⽤法。
3、tf.data.Dataset.from_tensor_slices
来看⼀个简单的demo就明⽩怎么使⽤tf.data.Dataset.from_tensor_slices了,代码如下,
# 创建迭代器
iterator = dataset.make_one_shot_iterator()
# 从迭代器中获取⼀个数据
_next()
def main(argv=None):
# 创建Dataset对象
dataset = tf.data.Dataset.from_tensor_slices(np.arange(0, 10))
data = get_data(dataset)
# 创建会话
with tf.Session() as sess:
try:
while True:
# 打印获取的data数据
print(sess.run(data))
except:
print('Done..')
if __name__ == '__main__':
tf.app.run()
⾸先,通过tf.data.Dataset.from_tensor_slices接⼝创建⼀个Dataset对象,然后,通过这个对象创建⼀个迭代器,再从迭代器中拿到数据,最后在会话中得到这些数据。运⾏结果如下,
4、数据转换
我们上⾯说过,Dataset可以直接对数据进⾏处理操作,那么,现在就基于上⾯的demo来看看怎么进⾏数据处理。
设置batch size
设置batch size很简单,只要在创建Dataset对象以后,直接设置即可,代码如下,
# 创建迭代器
iterator = dataset.make_one_shot_iterator()
# 从迭代器中获取⼀个数据
_next()
def main(argv=None):
# 创建Dataset对象
dataset = tf.data.Dataset.from_tensor_slices(np.arange(0, 10))
# 设置batch size
dataset = dataset.batch(2)
data = get_data(dataset)
# 创建会话
with tf.Session() as sess:
try:
while True:
# 打印获取的data数据
print(sess.run(data))
except:
print('Done..')
if __name__ == '__main__':
tf.app.run()
运⾏结果,
洗牌shuffle
接着来看对数据进⾏洗牌的操作,跟上⾯设置batch size的⽅式⼀样,所以这⾥就不放全部代码了,只放关键代码即可,代码如下,
# 洗牌操作,其中参数5是指定buffer_size
session如何设置和读取
dataset = dataset.shuffle(5)
运⾏结果,
那么这个buffer_size怎么理解呢?我们画个图来理解,
如上图所⽰,Dataset会根据buffer_size的值创建⼀个⼤⼩为buffer_size的缓冲区Buffer,然后,将所有数据All Data的前buffer_size个数据填充Buffer,
接着,从Buffer随机取⼀个数据输出,⽐如上图中就随机取出了item 3作为输出,那么,原来item 3的位置就会空出来,
此时,就会顺序的从All Data⾥选择⼀条数据填充到这个空出来的Buffer位置,然后再随机从Buffer中抽取⼀个数据输出,如此循环,就可以对数据进⾏洗牌操作。buffer_size越⼤,数据的顺序就会被洗得越乱。如果设置buffer_size为1,就会发现,数据的顺序没被洗乱。
复制数据repeat
接着来看复制数据操作,代码如下,
# 复制操作,其中参数2是复制次数
dataset = peat(2)
运⾏结果,
Map操作
Map操作主要是对数据集的每条数据进⾏指定的操作,⽐如,让数据集的每个数据乘以2,代码如下,
# Map操作,可以对每个数据进⾏指定操作
dataset = dataset.map(lambda x : x * 2)
运⾏结果,
Filter操作
Filter操作可以对数据进⾏过滤,⽐如,过滤掉数据中⼩于5的数,代码如下,
# filter操作,对数据进⾏过滤操作
dataset = dataset.filter(lambda x : x > 4)
运⾏结果,
5、将MNIST数据集以图⽚的形式保存
⽼规矩,还是以MNIST数据集为例,跟上⼀讲⼀样,将数据保存成图⽚的形式,如下图所⽰,
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论