学习keras 0x08–tf.data四种迭代器

2,840次阅读
没有评论

共计 7630 个字符,预计需要花费 20 分钟才能阅读完成。

tf.data四种迭代器

日常使用中单次迭代器应该是是用最多的,一般情况下数据量都是比较大,遍历一遍就搞定了。还是需要了解一下其他的迭代器,其实也是有相应的场合会需要这么去处理。

MNIST的经典例子

本篇博客结合 mnist 的经典例子,针对不同的源数据:csv数据和tfrecord数据,分别运用 tf.data.TextLineDataset() 和 tf.data.TFRecordDataset() 创建不同的 Dataset 并运用四种不同的 Iterator ,分别是 单次,可初始化,可重新初始化,以及可馈送迭代器 的方式实现对源数据的预处理工作。

  • make_one_shot_iterator
  • make_initializable_iterator
  • Reinitializable iterator
  • Feedable iterator

tf.data.TFRecordDataset() & make_one_shot_iterator()

tf.data.TFRecordDataset() 输入参数直接是后缀名为tfrecords的文件路径,正因如此,即可解决数据量过大,导致无法单机训练的问题。本篇博客中,文件路径即为/Users/***/Desktop/train_output.tfrecords,此处是我自己电脑上的路径,大家可以 根据自己的需要修改为对应的文件路径。 make_one_shot_iterator() 即为单次迭代器,是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。

单次迭代表示数据集是只迭代一次,但是你仍然可以将数据重复多个epoch的,只是重复之后也是只遍历一次完成数据的处理。

配合 MNIST数据集以及tf.data.TFRecordDataset(),实现代码如下。

# Validate tf.data.TFRecordDataset() using make_one_shot_iterator()import tensorflow as tf
import numpy as np
num_epochs = 2
num_class = 10
sess = tf.Session()
# Use `tf.parse_single_example()` to extract data from a `tf.Example`# protocol buffer, and perform any additional per-record preprocessing.def parser(record): keys_to_features = { "image_raw": tf.FixedLenFeature((), tf.string, default_value=""), "pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)), "label": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)), } parsed = tf.parse_single_example(record, keys_to_features) # Parse the string into an array of pixels corresponding to the imageimages = tf.decode_raw(parsed["image_raw"],tf.uint8) images = tf.reshape(images,[28,28,1]) labels = tf.cast(parsed['label'], tf.int32) labels = tf.one_hot(labels,num_class) pixels = tf.cast(parsed['pixels'], tf.int32) print("IMAGES",images) print("LABELS",labels) return {"image_raw": images}, labels
filenames = ["/Users/***/Desktop/train_output.tfrecords"]
# replace the filenames with your own pathdataset = tf.data.TFRecordDataset(filenames)
print("DATASET",dataset)
# Use `Dataset.map()` to build a pair of a feature dictionary and a label# tensor for each example.dataset = dataset.map(parser)
print("DATASET_1",dataset)
dataset = dataset.shuffle(buffer_size=10000)
print("DATASET_2",dataset)
dataset = dataset.batch(32)
print("DATASET_3",dataset)
dataset = dataset.repeat(num_epochs)
print("DATASET_4",dataset)
iterator = dataset.make_one_shot_iterator()
# `features` is a dictionary in which each value is a batch of values for# that feature; `labels` is a batch of labels.features, labels = iterator.get_next()
print("FEATURES",features)
print("LABELS",labels)
print("SESS_RUN_LABELS \n",sess.run(labels))

tf.data.TFRecordDataset() & Initializable iterator

make_initializable_iterator() 为可初始化迭代器,运用此迭代器首先需要先运行显式 iterator.initializer 操作,然后才能使用。并且,可运用 可初始化迭代器实现训练集和验证集的切换。 配合 MNIST数据集 实现代码如下。

这里公用了一套dataset处理流程,对于数据处理方式一样的数据集的确可以使用一套方式来处理,可初始化的迭代器表示可以使用不同的数据源来初始化该迭代器并且实现数据迭代的功能。

这里的迭代器可初始化意思就是你可以借助placeholder传参,每次传递不同的参数就得到不同的数据,数据处理的方式都是一致的,一定程度上有一定的可定制化。

# Validate tf.data.TFRecordDataset() using make_initializable_iterator()
# In order to switch between train and validation datanum_epochs = 2
num_class = 10
def parser(record): 
	keys_to_features = { "image_raw": tf.FixedLenFeature((), tf.string, default_value=""), "pixels": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)), "label": tf.FixedLenFeature((), tf.int64, default_value=tf.zeros([], dtype=tf.int64)), } 
	parsed = tf.parse_single_example(record, keys_to_features) # Parse the string into an array of pixels corresponding to the image
	images = tf.decode_raw(parsed["image_raw"],tf.uint8) images = tf.reshape(images,[28,28,1]) labels = tf.cast(parsed['label'], tf.int32) 
	labels = tf.one_hot(labels,10) 
	pixels = tf.cast(parsed['pixels'], tf.int32) 
	print("IMAGES",images) 
	print("LABELS",labels) 
	return {"image_raw": images}, labels
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parser) # Parse the record into tensors# print("DATASET",dataset)dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
print("DATASET",dataset)
iterator = dataset.make_initializable_iterator()
features, labels = iterator.get_next()
print("ITERATOR",iterator)
print("FEATURES",features)
print("LABELS",labels)
# Initialize `iterator` with training data.
training_filenames = ["/Users/honglan/Desktop/train_output.tfrecords"]
# replace the filenames with your own path
sess.run(iterator.initializer,feed_dict={filenames: training_filenames})
print("TRAIN\n",sess.run(labels))
# print(sess.run(features))# Initialize `iterator` with validation data.
validation_filenames = ["/Users/honglan/Desktop/val_output.tfrecords"]
# replace the filenames with your own path
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
print("VAL\n",sess.run(labels))

tf.data.TextLineDataset() & Reinitializable iterator

可重复初始化的迭代器,这个与之前的可初始化的迭代器有什么区别?

可重新初始化迭代器可以通过多个不同的 Dataset 对象进行初始化。例如,您可能有一个训练输入管道,它会对输入图片进行随机扰动来改善泛化;还有一个验证输入管道,它会评估对未修改数据的预测。这些管道通常会使用不同的 Dataset 对象,这些对象具有相同的结构(即每个组件具有相同类型和兼容形状)。

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)

从上面可以看出在循环的20次中,验证集的迭代器一直在不断的初始化迭代,可以理解为每运行一个epoch然后需要遍历所有的验证集,然后验证相应的效果。

tf.data.TextLineDataset() & Feedable iterator.

可馈送的迭代器,这个算是最复杂的迭代器,与之前的介绍的可重复初始化的迭代器不一样,这个在切换数据的时候不需要初始化操作,下面的例子就是可以充分的说明。我们定义了一个无线循环的训练集,然后我们需要遍历它,迭代方式只要迭代一次就行了,因为它是无线循环的嘛,所以在下面的while循环中就直接使用run不断的获取下一个数据即可

在我看来这种不需要初始化的情况一般都是在于大批量数据处理的情况下(无论是原始数据或者是经过epoch重复处理之后的),这个在训练的时候只要依此遍历就好了,不需要重复初始化。从下面的代码里面可以看出,这个可馈送的迭代器实际上充分利用了之前描述的迭代器,比如训练集使用的是单次迭代器,验证集使用的可初始化迭代器。所以这种数据集切换无需重新初始化只是一个相对的概念。

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
  # Run 200 steps using the training dataset. Note that the training dataset is
  # infinite, and we resume from where we left off in the previous `while` loop
  # iteration.
  for _ in range(200):
    sess.run(next_element, feed_dict={handle: training_handle})

  # Run one pass over the validation dataset.
  sess.run(validation_iterator.initializer)
  for _ in range(50):
    sess.run(next_element, feed_dict={handle: validation_handle})
正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2019-09-16发表,共计7630字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码