tf.estimator Basics


Estimator

Framework structure

Before introducing Estimator, it helps to have a rough idea of where it sits within the overall TensorFlow framework, as shown in the figure below:

[Figure: the TensorFlow API stack, with Estimator among the high-level APIs]

As the figure shows, Estimator is a high-level API, while the mid-level APIs are:

  • Layers: for building the network structure
  • Datasets: for building the data-input pipeline
  • Metrics: for evaluating network performance

In other words, when using Estimator we only need to focus on these three parts rather than on low-level details, and we no longer have to manage the annoying Session by hand.

Steps for using Estimator

  • Write one or more input functions, i.e. input_fn.
  • Define the model's feature columns, i.e. feature_columns (using feature columns is optional).
  • Instantiate the Estimator, specifying the feature columns and various hyperparameters.
  • Call one or more methods on the Estimator object (train, evaluate, predict), passing an appropriate input function as the source of data.
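Put together, these four steps map almost one-to-one onto code. Below is a minimal sketch using a premade DNNClassifier; the toy data, the feature name 'x' and the model_dir are placeholders, not anything from this article:

import tensorflow as tf

# 1. Input function (see the next section for a full Dataset-based version).
def train_input_fn():
    features = {'x': tf.constant([[1.0], [2.0], [3.0], [4.0]])}
    labels = tf.constant([0, 0, 1, 1])
    return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(2)

# 2. Feature columns describe how the raw features are used.
feature_columns = [tf.feature_column.numeric_column('x')]

# 3. Instantiate a premade Estimator with feature columns and hyperparameters.
estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[16, 8],
    n_classes=2,
    model_dir='/tmp/toy_model')   # placeholder directory

# 4. Call train / evaluate / predict with the appropriate input_fn.
estimator.train(input_fn=train_input_fn, steps=100)
metrics = estimator.evaluate(input_fn=train_input_fn, steps=10)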

Module breakdown

Data input module: input_fn

The function should construct and return one of the following: a tf.data.Dataset object whose outputs must be a tuple (features, labels), with the same constraints as below.

In other words, input_fn still has clear requirements on the data it returns. Normally we need two inputs:

  • a training input
  • an evaluation input

A CSV-reading example of such an input_fn:
import multiprocessing

import tensorflow as tf
from tensorflow import data

# MULTI_THREADING, PROCESS_FEATURES, parse_csv_row and process_features are
# assumed to be defined elsewhere in the original script.

def csv_input_fn(files_name_pattern, mode=tf.estimator.ModeKeys.EVAL,
                 skip_header_lines=0,
                 num_epochs=None,
                 batch_size=200):

    shuffle = True if mode == tf.estimator.ModeKeys.TRAIN else False

    num_threads = multiprocessing.cpu_count() if MULTI_THREADING else 1
     
    print("")
    print("* data input_fn:")
    print("================")
    print("Input file(s): {}".format(files_name_pattern))
    print("Batch size: {}".format(batch_size))
    print("Epoch Count: {}".format(num_epochs))
    print("Mode: {}".format(mode))
    print("Thread Count: {}".format(num_threads))
    print("Shuffle: {}".format(shuffle))
    print("================")
    print("")

    file_names = tf.matching_files(files_name_pattern)
    dataset = data.TextLineDataset(filenames=file_names)
    
    dataset = dataset.skip(skip_header_lines)
    
    if shuffle:
        dataset = dataset.shuffle(buffer_size=2 * batch_size + 1)
    
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda csv_row: parse_csv_row(csv_row), 
                          num_parallel_calls=num_threads)
    
    if PROCESS_FEATURES:
        dataset = dataset.map(lambda features, target: (process_features(features), target), 
                              num_parallel_calls=num_threads)
        
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    
    features, target = iterator.get_next()
    return features, target
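To obtain the separate training and evaluation inputs mentioned above, the same function is usually bound with different arguments. A small sketch (the file patterns are placeholders):

train_input_fn = lambda: csv_input_fn('data/train-*.csv',
                                      mode=tf.estimator.ModeKeys.TRAIN,
                                      num_epochs=None,
                                      batch_size=200)

eval_input_fn = lambda: csv_input_fn('data/valid-*.csv',
                                     mode=tf.estimator.ModeKeys.EVAL,
                                     num_epochs=1,
                                     batch_size=200)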

model_fn

The model function is generally defined as follows:

def model_fn(
    features,     # batch of features from input_fn: a `Tensor` or dict of `Tensor`s
    labels,       # batch of labels from input_fn
    mode,         # an instance of tf.estimator.ModeKeys
    params,       # additional configuration (a dict)
    config=None
    ):
  • The first two parameters are the feature and label batches returned by the input function; that is, features and labels are the data the model will consume.
  • params is a dictionary that can carry arbitrary values used to build the network or configure training, e.g. params['n_classes'] to set the number of output nodes (see the sketch after this list).
  • mode indicates whether the caller requested training, evaluation, or prediction, expressed as tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT. Looking at the DNNClassifier source code you can also see that mode never needs to be passed in manually; the Estimator sets it automatically. For example, when you call estimator.train(...), mode is set to tf.estimator.ModeKeys.TRAIN.
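The params dict itself is supplied when the custom Estimator is constructed. A small sketch with hypothetical keys:

classifier = tf.estimator.Estimator(
    model_fn=model_fn,
    params={
        'feature_columns': feature_columns,  # assumed to be defined elsewhere
        'hidden_units': [64, 32],
        'n_classes': 3,                      # hypothetical key read by model_fn
    })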

model_fn has to handle each mode differently, and in every case it must return an instance of tf.estimator.EstimatorSpec.

This may sound abstract at first, so here is the plain-language version: a model has three phases (training, evaluation, and prediction), and the data is handled differently in each. During training we feed data to the model, the model produces predictions from the inputs, we compute a loss from the predictions and the ground truth, and finally we use that loss to update the network parameters. During evaluation, by contrast, we do not back-propagate or update any parameters. In other words, model_fn needs to provide three code paths, one for each mode.

The required inputs for the three phases also differ:

  • train mode needs a loss and a train_op (the optimization step)
  • evaluate mode needs a loss and metrics to assess the model
  • predict mode only needs the prediction outputs; the checkpoint is simply loaded for inference

So what exactly does model_fn have to return? Estimator requires model_fn to return a tf.estimator.EstimatorSpec, so that all modes can be processed in a uniform way.

EstimatorSpec

An EstimatorSpec is what model_fn hands back to the Estimator; it spells out exactly how the Estimator should carry out the task at hand, for example how training is performed in train mode.

tf.estimator.EstimatorSpec(
    mode, predictions=None, loss=None, train_op=None, eval_metric_ops=None,
    export_outputs=None, training_chief_hooks=None, training_hooks=None,
    scaffold=None, evaluation_hooks=None, prediction_hooks=None
)

Parameter descriptions

  • mode: a ModeKeys value specifying whether this is training, evaluation, or prediction.
  • predictions: the predictions, a Tensor or dict of Tensors.
  • loss: the training loss Tensor; must be either a scalar or have shape [1].
  • train_op: the op that performs one training step.
  • eval_metric_ops: a dict of metric results keyed by name. Each value can be one of the following:
    • (1) an instance of a Metric class;
    • (2) the result of calling a metric function, i.e. a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically a pure computation based on variables); for example, it should not trigger the update_op or require any input fetching.

      The current SDK uses the second form: a name mapped to the corresponding (metric_tensor, update_op) pair.

Different modes require different fields

Depending on the value of mode, different fields must be supplied:

  • For mode == ModeKeys.TRAIN: the required fields are loss and train_op.
  • For mode == ModeKeys.EVAL: loss is required, and eval_metric_ops supplies the metrics.
  • For mode == ModeKeys.PREDICT: the required field is predictions.

The parameter descriptions above may still feel opaque, so here is an example for each mode.

The simplest case: predict

Only mode and predictions need to be passed in:

# Compute predictions.
predicted_classes = tf.argmax(logits, 1)

if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'class_ids': predicted_classes[:, tf.newaxis],
        'probabilities': tf.nn.softmax(logits),
        'logits': logits,
    }
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)
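On the caller side, estimator.predict then returns a generator of such prediction dicts, one per example. A usage sketch (predict_input_fn is assumed to be defined like the input functions above, just without labels):

predictions = estimator.predict(input_fn=predict_input_fn)
for pred in predictions:
    class_id = pred['class_ids'][0]
    probability = pred['probabilities'][class_id]
    print('Predicted class {} with probability {:.3f}'.format(class_id, probability))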

Evaluation mode: eval

Requires mode, loss, and eval_metric_ops.

If you call the Estimator's evaluate method, model_fn receives mode = ModeKeys.EVAL. In that case the model function must return a tf.estimator.EstimatorSpec containing the model's loss and, optionally, one or more metrics.

An example loss:

loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

TensorFlow provides the tf.metrics module for computing common metrics; here we take accuracy as an example:

# Compute evaluation metrics.
accuracy = tf.metrics.accuracy(labels=labels,
                               predictions=predicted_classes,
                               name='acc_op')

It is then returned like this:

metrics = {'accuracy': accuracy}

if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, eval_metric_ops=metrics)

Training mode: train

Requires mode, loss, and train_op.

The loss is the same as in eval mode:

# Compute loss.
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

An example train_op:

optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(loss,global_step=tf.train.get_global_step())

And the return value:

return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
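Putting the three branches together, a complete model_fn looks roughly like the sketch below. The single-hidden-layer network and the params keys ('feature_columns', 'hidden_units', 'n_classes') are illustrative assumptions, not something prescribed by Estimator:

def model_fn(features, labels, mode, params):
    # Build the network: input layer -> hidden layers -> logits.
    net = tf.feature_column.input_layer(features, params['feature_columns'])
    for units in params.get('hidden_units', [10]):
        net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
    logits = tf.layers.dense(net, params['n_classes'], activation=None)
    predicted_classes = tf.argmax(logits, 1)

    # PREDICT: only predictions are needed.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes[:, tf.newaxis],
            'probabilities': tf.nn.softmax(logits),
            'logits': logits,
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    # EVAL and TRAIN both need a loss.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # EVAL: loss plus metrics.
    accuracy = tf.metrics.accuracy(labels=labels,
                                   predictions=predicted_classes,
                                   name='acc_op')
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops={'accuracy': accuracy})

    # TRAIN: loss plus train_op.
    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)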

Hook

A hook is, as the name suggests, a piece of logic that "hooks into" program execution: a task that runs whenever certain trigger conditions are met while your program executes. You can think of it as the observer pattern from design patterns; hooks sit there watching the program run.

Before going through the concrete hooks, let's first look at their parent class.

tf.train.SessionRunHook

class SessionRunHook(object):
  """Hook to extend calls to MonitoredSession.run()."""
 
  def begin(self):
    """再创建会话之前调用
    调用begin()时,default graph会被创建,
    可在此处向default graph增加新op,begin()调用后,default graph不能再被修改
    """
    pass
 
  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """tf.Session被创建后调用
    调用后会指示所有的Hooks有一个新的会话被创建
    Args:
      session: A TensorFlow Session that has been created.
      coord: A Coordinator object which keeps track of all threads.
    """
    pass
 
  def before_run(self, run_context):  # pylint: disable=unused-argument
    """调用在每个sess.run()执行之前
    可以返回一个tf.train.SessRunArgs(op/tensor),在即将运行的会话中加入这些op/tensor;
    加入的op/tensor会和sess.run()中已定义的op/tensor合并,然后一起执行;
    Args:
      run_context: A `SessionRunContext` object.
    Returns:
      None or a `SessionRunArgs` object.
    """
    return None

  def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):  # pylint: disable=unused-argument
    """调用在每个sess.run()之后
    参数run_values是befor_run()中要求的op/tensor的返回值;
    可以调用run_context.qeruest_stop()用于停止迭代
    sess.run抛出任何异常after_run不会被调用
    Args:
      run_context: A `SessionRunContext` object.
      run_values: A SessionRunValues object.
    """
    pass
 
  def end(self, session):  # pylint: disable=unused-argument
    """在会话结束时调用
    end()常被用于Hook想要执行最后的操作,如保存最后一个checkpoint
    如果sess.run()抛出除了代表迭代结束的OutOfRange/StopIteration异常外,
    end()不会被调用
    Args:
      session: A TensorFlow Session that will be soon closed.
    """
    pass

To summarize, there are five trigger points:

  • before the session is created (begin)
  • after the session is created (after_create_session)
  • before each sess.run() call (before_run)
  • after each sess.run() call (after_run)
  • when the session ends (end)

A hook you define can perform its own actions at each of these five points. Of course, you rarely need all five; if you only care about a particular stage, just override the corresponding method. A minimal example follows.
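As a concrete illustration, here is a minimal custom hook (a sketch, not part of TensorFlow) that overrides only two of the five methods in order to time each training step:

import time

import tensorflow as tf


class StepTimerHook(tf.train.SessionRunHook):
    """Logs the wall-clock time of every sess.run() call."""

    def before_run(self, run_context):
        self._start = time.time()
        return None  # no extra ops/tensors requested

    def after_run(self, run_context, run_values):
        tf.logging.info('step took %.3f sec', time.time() - self._start)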

TensorFlow itself defines many hooks:

  • …..

Below we take one of the simplest, LoggingTensorHook, to explain how a hook works.

class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.
  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:
  ```python
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  ```
  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.
  In short: it prints the given tensors every N steps or every N seconds; in
  practice we usually log by step count, e.g. the training loss, AUC, and so on.
  """

  def __init__(self,
               tensors,
               every_n_iter=None,
               every_n_secs=None,
               at_end=False,
               formatter=None):
    """Initializes a `LoggingTensorHook`.
    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names, or
        `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every N
        seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
        provided.
      at_end: `bool` specifying whether to print the values of `tensors` at the
        end of the run.
      formatter: function, takes dict of `tag`->`Tensor` and returns a string.
        If `None` uses default printing all tensors.
    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    # Flag: whether to log only at the end of the run
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    # Validate the every_n_iter / every_n_secs arguments
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    # Validate the type of `tensors`; convert it to a dict if it is not one already
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
            every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    # Reset the timer and the iteration counter
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given
    self._current_tensors = {
        tag: _as_graph_element(tensor)
        for (tag, tensor) in self._tensors.items()
    }

  def before_run(self, run_context):  # pylint: disable=unused-argument
    # Decide whether it is time to log on this step
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    # Log the values of the requested tensors
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    # Compute the time elapsed since the last trigger
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)
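Using it with an Estimator is straightforward: map display tags to tensor names that exist in your graph (the names 'loss' and 'global_step' below are assumptions about the graph; adjust them to your model) and pass the hook to train:

logging_hook = tf.train.LoggingTensorHook(
    tensors={'loss': 'loss', 'step': 'global_step'},  # tag -> tensor name in the graph
    every_n_iter=100)

estimator.train(input_fn=train_input_fn, hooks=[logging_hook], max_steps=10000)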

SessionRunContext / SessionRunValues / SessionRunArgs

These three classes, tf.train.SessionRunContext, tf.train.SessionRunValues, and tf.train.SessionRunArgs, all serve sess.run():

tf.train.SessionRunContext and tf.train.SessionRunArgs provide the information a session run needs,

while tf.train.SessionRunValues holds the results of the run.

(1) tf.train.SessionRunArgs

Arguments handed to the session run, defined the same way as the arguments of sess.run():

fetches, feed_dict, options

(2) tf.train.SessionRunValues

Holds the results of sess.run(); its results field contains the sess.run() return values corresponding to the SessionRunArgs requested in before_run().

(3) tf.train.SessionRunContext

SessionRunContext carries all the information needed by sess.run().
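A short sketch (again a custom hook, not TensorFlow code) of how the three classes interact: before_run requests extra fetches through SessionRunArgs, and after_run reads them back from run_values.results:

class GlobalStepWatcherHook(tf.train.SessionRunHook):
    """Fetches the global step on every run and logs it periodically."""

    def begin(self):
        # The default graph already exists here; grab (or create) the global step.
        self._global_step = tf.train.get_or_create_global_step()

    def before_run(self, run_context):
        # Ask the monitored session to fetch the global step as well.
        return tf.train.SessionRunArgs(fetches=self._global_step)

    def after_run(self, run_context, run_values):
        # run_values.results holds the value fetched in before_run.
        step = run_values.results
        if step % 1000 == 0:
            tf.logging.info('reached global step %d', step)
        # run_context.request_stop() could be called here to stop training early.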

How does an Estimator actually run?

Let's start from the entry point, which is train_and_evaluate:

def train_and_evaluate(estimator, train_spec, eval_spec):
  """Train and evaluate the `estimator`.

  This utility function trains, evaluates, and (optionally) exports the model by
  using the given `estimator`. All training related specification is held in
  `train_spec`, including training `input_fn` and training max steps, etc. All
  evaluation and export related specification is held in `eval_spec`, including
  evaluation `input_fn`, steps, etc.

  This utility function provides consistent behavior for both local
  (non-distributed) and distributed configurations. The default distribution
  configuration is parameter server-based between-graph replication. For other
  types of distribution configurations such as all-reduce training, please use
  [DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute).

  Overfitting: In order to avoid overfitting, it is recommended to set up the
  training `input_fn` to shuffle the training data properly.

  Stop condition: In order to support both distributed and non-distributed
  configuration reliably, the only supported stop condition for model
  training is `train_spec.max_steps`. If `train_spec.max_steps` is `None`, the
  model is trained forever. *Use with care* if model stop condition is
  different. For example, assume that the model is expected to be trained with
  one epoch of training data, and the training `input_fn` is configured to throw
  `OutOfRangeError` after going through one epoch, which stops the
  `Estimator.train`. For a three-training-worker distributed configuration, each
  training worker is likely to go through the whole epoch independently. So, the
  model will be trained with three epochs of training data instead of one epoch.

  Example of local (non-distributed) training:

  ```python
  # Set up feature columns.
  categorial_feature_a = categorial_column_with_hash_bucket(...)
  categorial_feature_a_emb = embedding_column(
      categorical_column=categorial_feature_a, ...)
  ...  # other feature columns

  estimator = DNNClassifier(
      feature_columns=[categorial_feature_a_emb, ...],
      hidden_units=[1024, 512, 256])

  # Or set up the model directory
  #   estimator = DNNClassifier(
  #       config=tf.estimator.RunConfig(
  #           model_dir='/my_model', save_summary_steps=100),
  #       feature_columns=[categorial_feature_a_emb, ...],
  #       hidden_units=[1024, 512, 256])

  # Input pipeline for train and evaluate.
  def train_input_fn(): # returns x, y
    # please shuffle the data.
    pass
  def eval_input_fn(): # returns x, y
    pass

  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)

  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  ```
  Note that in current implementation `estimator.evaluate` will be called
  multiple times. This means that evaluation graph (including eval_input_fn)
  will be re-created for each `evaluate` call. `estimator.train` will be called
  only once.

  Example of distributed training:

  Regarding the example of distributed training, the code above can be used
  without a change (Please do make sure that the `RunConfig.model_dir` for all
  workers is set to the same directory, i.e., a shared file system all workers
  can read and write). The only extra work to do is setting the environment
  variable `TF_CONFIG` properly for each worker correspondingly.

  Also see
  [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).

  Setting environment variable depends on the platform. For example, on Linux,
  it can be done as follows (`$` is the shell prompt):

  ```
  $ TF_CONFIG='<replace_with_real_content>' python train_model.py
  ```

  For the content in `TF_CONFIG`, assume that the training cluster spec looks
  like:

  ```
  cluster = {"chief": ["host0:2222"],
             "worker": ["host1:2222", "host2:2222", "host3:2222"],
             "ps": ["host4:2222", "host5:2222"]}
  ```

  Example of `TF_CONFIG` for chief training worker (must have one and only one):

  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "chief", "index": 0}
  }'
  ```
  Note that the chief worker also does the model training job, similar to other
  non-chief training workers (see next paragraph). In addition to the model
  training, it manages some extra work, e.g., checkpoint saving and restoring,
  writing summaries, etc.

  Example of `TF_CONFIG` for non-chief training worker (optional, could be
  multiple):

  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 0}
  }'
  ```
  where the `task.index` should be set as 0, 1, 2, in this example, respectively
  for non-chief training workers.

  Example of `TF_CONFIG` for parameter server, aka ps (could be multiple):

  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "ps", "index": 0}
  }'
  ```
  where the `task.index` should be set as 0 and 1, in this example, respectively
  for parameter servers.

  Example of `TF_CONFIG` for evaluator task. Evaluator is a special task that is
  not part of the training cluster. There could be only one. It is used for
  model evaluation.

  ```
  # This should be a JSON string, which is set as environment variable. Usually
  # the cluster manager handles that.
  TF_CONFIG='{
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "evaluator", "index": 0}
  }'
  ```

  When `distribute` or `experimental_distribute.train_distribute` and
  `experimental_distribute.remote_cluster` is set, this method will start a
  client running on the current host which connects to the `remote_cluster` for
  training and evaluation.

  Args:
    estimator: An `Estimator` instance to train and evaluate.
    train_spec: A `TrainSpec` instance to specify the training specification.
    eval_spec: A `EvalSpec` instance to specify the evaluation and export
      specification.

  Returns:
    A tuple of the result of the `evaluate` call to the `Estimator` and the
    export results using the specified `ExportStrategy`.
    Currently, the return value is undefined for distributed training mode.

  Raises:
    ValueError: if environment variable `TF_CONFIG` is incorrectly set.
  """
  _assert_eval_spec(eval_spec)  # fail fast if eval_spec is invalid.

  executor = _TrainingExecutor(
      estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
  config = estimator.config

  # If `distribute_coordinator_mode` is set and running in distributed
  # environment, we run `train_and_evaluate` via distribute coordinator.
  if distribute_coordinator_training.should_run_distribute_coordinator(config):
    logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
    distribute_coordinator_training.train_and_evaluate(
        estimator, train_spec, eval_spec, _TrainingExecutor)
    return

  if (config.task_type == run_config_lib.TaskType.EVALUATOR and
      config.task_id > 0):
    raise ValueError(
        'For distributed training, there can only be one `evaluator` task '
        '(with task id 0).  Given task id {}'.format(config.task_id))

  return executor.run()

In the function above the final return value comes from executor.run(), i.e. _TrainingExecutor.run(). Let's drill down and look at the definition of _TrainingExecutor first.

The constructor mostly validates the arguments that are passed in:


class _TrainingExecutor(object):
  """The executor to run `Estimator` training and evaluation.

  This implementation supports both distributed and non-distributed (aka local)
  training and evaluation based on the setting in `tf.estimator.RunConfig`.
  """

  def __init__(self,
               estimator,
               train_spec,
               eval_spec,
               train_hooks=None,
               continuous_eval_listener=None):
    if not isinstance(estimator,
                      (estimator_lib.Estimator, estimator_lib.EstimatorV2)):
      raise TypeError(
          '`estimator` must have type `tf.estimator.Estimator`. '
          'Got: {}'.format(type(estimator)))
    self._estimator = estimator

    if not isinstance(train_spec, TrainSpec):
      raise TypeError(
          '`train_spec` must have type `tf.estimator.TrainSpec`. '
          'Got: {}'.format(type(train_spec)))
    self._train_spec = train_spec

    if eval_spec and not isinstance(eval_spec, EvalSpec):
      raise TypeError('`eval_spec` must be either `None` or have type '
                      '`tf.estimator.EvalSpec`. Got: {}'.format(
                          type(eval_spec)))
    self._eval_spec = eval_spec

    self._train_hooks = _validate_hooks(train_hooks)

    if (continuous_eval_listener and
        not isinstance(continuous_eval_listener, _ContinuousEvalListener)):
      raise TypeError('`continuous_eval_listener` must have type '
                      '`_ContinuousEvalListener`.')
    self._continuous_eval_listener = (
        continuous_eval_listener or _ContinuousEvalListener())

Now look at its run method.

The run method behaves differently for distributed and single-machine execution; which predefined procedure it invokes depends on the task type determined when the RunConfig was created, e.g. cluster versus local.

def run(self):
    """Executes the run_foo for task type `foo`.

    `_TrainingExecutor` predefines the procedure for task type 'chief',
    'worker', 'ps', and 'evaluator'. For task type `foo`, the corresponding
    procedure is `run_foo'. This `run` method invoke the procedure base on the
    `RunConfig.task_type`.

    Returns:
      A tuple of the result of the `evaluate` call to the `Estimator` and the
      export results using the specified `ExportStrategy`.
      Currently undefined for distributed training mode.

    Raises:
      ValueError: if the estimator.config is mis-configured.
    """
    config = self._estimator.config

    if (not config.cluster_spec and
        config.task_type != run_config_lib.TaskType.EVALUATOR):
      logging.info('Running training and evaluation locally (non-distributed).')
      return self.run_local()

    # Distributed case.
    if not config.task_type:
      # TODO(xiejw): Improve the error message about how to set the TF_CONFIG
      # correctly.
      raise ValueError(
          '`estimator.config` must have task_type set. This usually means '
          'TF_CONFIG environment is not set correctly.')

    if config.task_type == 'local':
      raise ValueError(
          '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '
          '`task` properties in TF_CONFIG absent triggers train and evaluate '
          '`Estimator` locally (non-distributed).')

    # For task type foo, call executor.run_foo.
    available_tasks = [
        x for x in dir(self)
        if x.startswith('run_') and x != 'run_local' and
        callable(getattr(self, x))
    ]
    task_to_run = 'run_' + config.task_type
    if task_to_run not in available_tasks:
      raise ValueError(
          'Task type {} is not supported. Supported task types are {}'.format(
              config.task_type, [x[len('run_'):] for x in available_tasks]))
    getattr(self, task_to_run)()

Let's use run_local to illustrate:

def run_local(self):
    """Runs training and evaluation locally (non-distributed)."""
    _assert_eval_spec(self._eval_spec)

    train_hooks = list(self._train_spec.hooks) + list(self._train_hooks)
    logging.info('Start train and evaluate loop. The evaluate will happen '
                 'after every checkpoint. Checkpoint frequency is determined '
                 'based on RunConfig arguments: save_checkpoints_steps {} or '
                 'save_checkpoints_secs {}.'.format(
                     self._estimator.config.save_checkpoints_steps,
                     self._estimator.config.save_checkpoints_secs))

    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                             self._train_spec.max_steps)
    # Listener that triggers evaluation whenever a new checkpoint is saved
    listener_for_eval = _NewCheckpointListenerForEvaluate(
        evaluator, self._eval_spec.throttle_secs,
        self._continuous_eval_listener)
    saving_listeners = [listener_for_eval]
    # Here the estimator's train() method is called to run the actual training
    self._estimator.train(
        input_fn=self._train_spec.input_fn,
        max_steps=self._train_spec.max_steps,
        hooks=train_hooks,
        saving_listeners=saving_listeners)

    eval_result = listener_for_eval.eval_result or _EvalResult(
        status=_EvalStatus.MISSING_CHECKPOINT)
    return eval_result.metrics, listener_for_eval.export_results
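Note the log message above: in local mode, evaluation runs after every checkpoint, so the evaluation frequency is effectively controlled by the checkpoint settings in RunConfig (plus EvalSpec.throttle_secs). A sketch with arbitrary numbers:

run_config = tf.estimator.RunConfig(
    model_dir='/tmp/my_model',        # placeholder directory
    save_checkpoints_steps=1000,      # checkpoint (and therefore evaluate) every 1000 steps
    keep_checkpoint_max=5)

estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,  # assumed to be defined as earlier
    hidden_units=[64, 32],
    config=run_config)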

Drilling further down into the code:

def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    """Trains a model given training data `input_fn`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
        See [Premade Estimators](
        https://tensorflow.org/guide/premade_estimators#create_input_functions)
        for more information. The function should construct and return one of
        the following:  * A
        `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
        `(features, labels)` with same constraints as below. * A tuple
        `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
        of string feature name to `Tensor` and `labels` is a `Tensor` or a
        dictionary of string label name to `Tensor`. Both `features` and
        `labels` are consumed by `model_fn`. They should satisfy the expectation
        of `model_fn` from inputs.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      steps: Number of steps for which to train the model. If `None`, train
        forever or train until `input_fn` generates the `tf.errors.OutOfRange`
        error or `StopIteration` exception. `steps` works incrementally. If you
        call two times `train(steps=10)` then training occurs in total 20 steps.
        If `OutOfRange` or `StopIteration` occurs in the middle, training stops
        before 20 steps. If you don't want to have incremental behavior please
        set `max_steps` instead. If set, `max_steps` must be `None`.
      max_steps: Number of total steps for which to train model. If `None`,
        train forever or train until `input_fn` generates the
        `tf.errors.OutOfRange` error or `StopIteration` exception. If set,
        `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the
        middle, training stops before `max_steps` steps. Two calls to
        `train(steps=100)` means 200 training iterations. On the other hand, two
        calls to `train(max_steps=100)` means that the second call will not do
        any iteration since first call did all 100 steps.
      saving_listeners: list of `CheckpointSaverListener` objects. Used for
        callbacks that run immediately before or after checkpoint savings.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If both `steps` and `max_steps` are not `None`.
      ValueError: If either `steps` or `max_steps <= 0`.
    """
    _estimator_api_gauge.get_cell('train').set(True)
    if self.config.task_type in (run_config.TaskType.EVALUATOR,
                                 run_config.TaskType.PS):
      raise ValueError(
          'Train has been called wrong configuration. Please use '
          'tf.estimator.train_and_evaluate which calls proper API according '
          'to given configuration. Current configuration: {}.'.format(
              self.config))

    with context.graph_mode():
      if (steps is not None) and (max_steps is not None):
        raise ValueError('Can not provide both steps and max_steps.')
      if steps is not None and steps <= 0:
        raise ValueError('Must specify steps > 0, given: {}'.format(steps))
      if max_steps is not None and max_steps <= 0:
        raise ValueError(
            'Must specify max_steps > 0, given: {}'.format(max_steps))

      if max_steps is not None:
        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
        if max_steps <= start_step:
          logging.info('Skipping training since max_steps has already saved.')
          return self

      hooks = _check_hooks_type(hooks)
      hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))

      saving_listeners = _check_listeners_type(saving_listeners)
      loss = self._train_model(input_fn, hooks, saving_listeners)
      logging.info('Loss for final step: %s.', loss)
      return self

This brings us to _train_model:

def _train_model(self, input_fn, hooks, saving_listeners):
    if self._train_distribution:
      return self._train_model_distributed(input_fn, hooks, saving_listeners)
    else:
      return self._train_model_default(input_fn, hooks, saving_listeners)

  def _train_model_default(self, input_fn, hooks, saving_listeners):
    """Initiate training with `input_fn`, without `DistributionStrategies`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
        for callbacks that run immediately before or after checkpoint savings.

    Returns:
      Loss from training
    """
    worker_hooks = []
    # A fresh graph is created and used as the default graph
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)

      # Skip creating a read variable if _create_and_assert_global_step
      # returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
      if global_step_tensor is not None:
        training_util._get_or_create_global_step_read(g)  # pylint: disable=protected-access
      # Data iterator: obtain the features and labels produced by input_fn
      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, ModeKeys.TRAIN))
      worker_hooks.extend(input_hooks)
      # Call the user-defined model_fn; the mode (ModeKeys.TRAIN here) is filled in automatically
      estimator_spec = self._call_model_fn(
          features, labels, ModeKeys.TRAIN, self.config)
      global_step_tensor = training_util.get_global_step(g)
      return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                             hooks, global_step_tensor,
                                             saving_listeners)

Next comes the heart of the matter, the _train_with_estimator_spec function.

def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                 global_step_tensor, saving_listeners):
    """Train a model with the given Estimator Spec."""
    # Check whether to warm-start from existing weights
    if self._warm_start_settings:
      logging.info('Warm-starting with WarmStartSettings: %s' %
                   (self._warm_start_settings,))
      warm_starting_util.warm_start(*self._warm_start_settings)
    # Check if the user created a loss summary, and add one if they didn't.
    # We assume here that the summary is called 'loss'. If it is not, we will
    # make another one with the name 'loss' to ensure it shows up in the right
    # graph in TensorBoard.
    if not any([x.op.name == 'loss'
                for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
      summary.scalar('loss', estimator_spec.loss)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    worker_hooks.extend(hooks)
    worker_hooks.append(
        training.NanTensorHook(estimator_spec.loss)
    )
    if self._config.log_step_count_steps is not None:
      worker_hooks.append(
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=self._config.log_step_count_steps)
      )
    # Add the hooks defined in the EstimatorSpec, i.e. the ones we specified ourselves
    worker_hooks.extend(estimator_spec.training_hooks)

    if not (estimator_spec.scaffold.saver or
            ops.get_collection(ops.GraphKeys.SAVERS)):
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))

    if (self._config.cluster_spec and type(
        self._train_distribution).__name__ in ('CollectiveAllReduceStrategy',
                                               'CollectiveAllReduceStrategyV1',
                                               'MultiWorkerMirroredStrategy')):
      return self._train_with_estimator_spec_distributed(
          estimator_spec, worker_hooks, saving_listeners)

    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      if not saver_hooks:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold)
        ]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
      if not saver_hooks:
        raise ValueError(
            'There should be a CheckpointSaverHook to use saving_listeners. '
            'Please set one of the RunConfig.save_checkpoints_steps or '
            'RunConfig.save_checkpoints_secs.')
      else:
        # It is expected to have one CheckpointSaverHook. If multiple, we pick
        # up the first one to add listener.
        saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access

    # Add summary hooks to worker 0 if we are running with a master, to ensure
    # that summaries are written at correct intervals even with long-running
    # evaluations.
    save_summary_steps = self._config.save_summary_steps
    log_step_count_steps = self._config.log_step_count_steps

    # Check existence of appropriate cluster spec fields, as well as master and
    # worker nodes. As master also performs evaluation, summary writing must
    # occur on a different node. The presence of a worker is also checked to
    # prevent reassigning hooks for single-replica jobs with just a master node.
    if (self._config.cluster_spec and self._config.cluster_spec.jobs and
        (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
        (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
      # Update config values to prevent the default hooks from being created on
      # the master or other workers.
      save_summary_steps = 0
      log_step_count_steps = None

      if (self._config.task_type == run_config.TaskType.WORKER and
          self._config.task_id == 0):
        if (self._config.save_summary_steps and
            self._config.save_summary_steps > 0):
          worker_hooks.append(
              training.SummarySaverHook(
                  save_steps=self._config.save_summary_steps,
                  output_dir=self._config.model_dir,
                  scaffold=estimator_spec.scaffold))

        if (self._config.log_step_count_steps and
            self._config.log_step_count_steps > 0):
          worker_hooks.append(
              training.StepCounterHook(
                  every_n_steps=self._config.log_step_count_steps,
                  output_dir=self._config.model_dir))
    # This is where the monitored session actually starts running
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(
            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=save_summary_steps,
        config=self._session_config,
        log_step_count_steps=log_step_count_steps) as mon_sess:
      loss = None
      any_step_done = False
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
        any_step_done = True
    if not any_step_done:
      logging.warning('Training with estimator made no steps. '
                      'Perhaps input is empty or misspecified.')
    return loss

The hook invocations inside the MonitoredTrainingSession above are equivalent to the following logic:

# your_hooks is a list of SessionRunHook objects
with MonitoredTrainingSession(hooks=your_hooks, ...) as sess:
    while not sess.should_stop():
        sess.run(your_fetches)

# The above is equivalent to:
call hooks.begin() # begin
sess = tf.Session()
call hooks.after_create_session()  # after_create_session
while not stop is requested:
    call hooks.before_run()  # before_run
    try:
        results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
        break
    call hooks.after_run()  # after_run
call hooks.end()  # end
sess.close()

While reading through the code you will also notice the CheckpointSaverListener class. It is an interface for listening to CheckpointSaverHook, used to run a series of custom actions before and after each checkpoint save.

@tf_export("train.CheckpointSaverListener")
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is
  triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own schedule
  to act less frequently, e.g. based on global_step_value. In this case,
  implementors should implement the `end()` method to handle actions related to
  the last checkpoint save. But the listener should not act twice if
  `after_save()` already handled this last checkpoint save.

  A `CheckpointSaverListener` can request training to be stopped, by returning
  True in `after_save`. Please note that, in replicated distributed training
  setting, only `chief` should use this behavior. Otherwise each worker will do
  their own evaluation, which may be wasteful of resources.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass

Finally, let's look at how the evaluate step is run after a checkpoint has been saved:

class _NewCheckpointListenerForEvaluate(
    tf.compat.v1.train.CheckpointSaverListener):
  """A saver listener to run evaluate with every checkpoint."""

  def __init__(self, evaluator, eval_throttle_secs, continuous_eval_listener):
    self._evaluator = evaluator
    self._eval_throttle_secs = eval_throttle_secs
    self._continuous_eval_listener = continuous_eval_listener
    self.eval_result, self.export_results = None, None

  def begin(self):
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_secs=self._eval_throttle_secs)
    self._is_first_run = True

  def after_save(self, session, global_step_value):
    del session  # unused; required by signature.
    # skip first run model is not trained yet.
    if self._is_first_run:
      self._is_first_run = False
      return

    if not self._continuous_eval_listener.before_eval():
      tf.compat.v1.logging.info(
          'Exiting training and evaluation loop, as requested by '
          '_ContinuousEvalListener.before_eval.')
      return True
    if self._timer.should_trigger_for_step(global_step_value):
      # This is where evaluate is triggered; as a listener it can fire many times during the session
      self._evaluate(global_step_value)  # updates self.eval_result
      if not self._continuous_eval_listener.after_eval(self.eval_result):
        tf.compat.v1.logging.info('Exiting evaluation, as requested by '
                                  '_ContinuousEvalListener.after_eval.')
        return True
    else:
      # TODO(ispir): add remaining time in the log.
      tf.compat.v1.logging.info(
          'Skip the current checkpoint eval due to throttle secs '
          '({} secs).'.format(self._eval_throttle_secs))

  def end(self, session, global_step_value):
    # Evaluate if the last step has not been evaluated, yet.
    if global_step_value != self._timer.last_triggered_step():
      if self._continuous_eval_listener.before_eval():
        self._evaluate(global_step_value)
        self._continuous_eval_listener.after_eval(self.eval_result)

  def _evaluate(self, global_step_value):
    self._timer.update_last_triggered_step(global_step_value)
    self.eval_result, self.export_results = (
        self._evaluator.evaluate_and_export())
    if self.eval_result.status != _EvalStatus.EVALUATED:
      #  This is unexpected; should never happen.
      #  Training should always end with a new checkpoint.
      raise RuntimeError('There was no new checkpoint after the training. '
                         'Eval status: {}'.format(self.eval_result.status))