This started during a code review, when someone asked where exactly the Estimator manages its Session. Hmm, I had never read that part of the code carefully and honestly couldn't answer on the spot. Searching online took a while too; most answers simply said not to worry about the Session, meaning that since this is a high-level API you never need to touch the Session yourself, and indeed the Estimator was designed with that in mind. Still, the answer has to be somewhere in the source, so let's start from train():
```python
def train(self, input_fn, hooks=None, steps=None, max_steps=None, saving_listeners=None):
  """Trains a model given training data `input_fn`.

  Args:
    input_fn: A function that provides input data for training as minibatches.
      See [Premade Estimators](
      https://tensorflow.org/guide/premade_estimators#create_input_functions)
      for more information. The function should construct and return one of
      the following:
      * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a
        tuple `(features, labels)` with same constraints as below.
      * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or a
        dictionary of string feature name to `Tensor` and `labels` is a
        `Tensor` or a dictionary of string label name to `Tensor`. Both
        `features` and `labels` are consumed by `model_fn`. They should
        satisfy the expectation of `model_fn` from inputs.
    hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
      callbacks inside the training loop.
    steps: Number of steps for which to train the model. If `None`, train
      forever or train until `input_fn` generates the `tf.errors.OutOfRange`
      error or `StopIteration` exception. `steps` works incrementally. If you
      call two times `train(steps=10)` then training occurs in total 20 steps.
      If `OutOfRange` or `StopIteration` occurs in the middle, training stops
      before 20 steps. If you don't want to have incremental behavior please
      set `max_steps` instead. If set, `max_steps` must be `None`.
    max_steps: Number of total steps for which to train model. If `None`,
      train forever or train until `input_fn` generates the
      `tf.errors.OutOfRange` error or `StopIteration` exception. If set,
      `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the
      middle, training stops before `max_steps` steps. Two calls to
      `train(steps=100)` means 200 training iterations. On the other hand, two
      calls to `train(max_steps=100)` means that the second call will not do
      any iteration since first call did all 100 steps.
    saving_listeners: list of `CheckpointSaverListener` objects. Used for
      callbacks that run immediately before or after checkpoint savings.

  Returns:
    `self`, for chaining.

  Raises:
    ValueError: If both `steps` and `max_steps` are not `None`.
    ValueError: If either `steps` or `max_steps <= 0`.
  """
  _estimator_api_gauge.get_cell('train').set(True)
  if self.config.task_type in (run_config.TaskType.EVALUATOR,
                               run_config.TaskType.PS):
    raise ValueError(
        'Train has been called wrong configuration. Please use '
        'tf.estimator.train_and_evaluate which calls proper API according '
        'to given configuration. Current configuration: {}.'.format(
            self.config))
  with context.graph_mode():
    if (steps is not None) and (max_steps is not None):
      raise ValueError('Can not provide both steps and max_steps.')
    if steps is not None and steps <= 0:
      raise ValueError('Must specify steps > 0, given: {}'.format(steps))
    if max_steps is not None and max_steps <= 0:
      raise ValueError(
          'Must specify max_steps > 0, given: {}'.format(max_steps))

    if max_steps is not None:
      # Restore the global step from the checkpoint directory.
      start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
      if max_steps <= start_step:
        logging.info('Skipping training since max_steps has already saved.')
        return self

    # Check the hook types: every hook must derive from SessionRunHook.
    hooks = _check_hooks_type(hooks)
    hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))

    # saving_listeners hold the checkpoint-saving callbacks.
    saving_listeners = _check_listeners_type(saving_listeners)
    # _train_model is the key part: the Session turns out to be created in there.
    loss = self._train_model(input_fn, hooks, saving_listeners)
    logging.info('Loss for final step: %s.', loss)
    return self
```
Parameters:
input_fn: a function that provides minibatches of data for training; for usage details see Premade Estimators. Its return value must be one of the following: (1) a tf.data.Dataset object, whose output must be a (features, labels) tuple with the same constraints as below; (2) a (features, labels) tuple, where features is a tf.Tensor or a dictionary mapping string feature names to Tensors, and labels likewise. Both features and labels are consumed by model_fn and must satisfy model_fn's input expectations.
hooks: a list of tf.train.SessionRunHook subclass instances, used for callbacks inside the training loop. A bit of personal understanding: a hook is a class that serves the Estimator; it has begin, after_create_session, before_run, after_run and end methods, which run respectively before the Session is created, right after it is created, before each Session run, after each Session run, and just before the Session is closed. A small sketch follows this parameter list; see also the [reference code](## 附录).
steps: the number of steps to train the model. If None, the model trains forever, or until input_fn raises the tf.errors.OutOfRange error or the StopIteration exception. steps is incremental: if you call train(steps=10) twice, the total number of training steps is 20. If OutOfRange or StopIteration occurs in the middle, training stops before reaching 20 steps. If you don't want this incremental behavior, set max_steps instead. If steps is set, max_steps must be None.
max_steps: the total number of steps to train the model. If None, the model trains forever, or until input_fn raises the tf.errors.OutOfRange error or the StopIteration exception. If set, steps must be None. If OutOfRange or StopIteration occurs in the middle, training stops before max_steps. Calling train(steps=100) twice means 200 training steps in total, whereas calling train(max_steps=100) twice only trains 100 steps, because the first call already reached the maximum.
saving_listeners: a list of CheckpointSaverListener objects, used for callbacks that run immediately before or after checkpoint saves.
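To make the hooks and steps/max_steps semantics concrete, here is a minimal sketch, assuming a TF 1.x-style tf.estimator environment. LoggingLifecycleHook and the toy input_fn are made-up names for illustration; they are not part of the Estimator source.

```python
import tensorflow as tf


class LoggingLifecycleHook(tf.estimator.SessionRunHook):
    """A hypothetical hook that logs the global step around each session run."""

    def begin(self):
        # Called once before the Session is created; the graph can still be modified here.
        self._step_tensor = tf.compat.v1.train.get_or_create_global_step()

    def after_create_session(self, session, coord):
        print('session created')

    def before_run(self, run_context):
        # Ask the session to fetch the global step alongside the training ops.
        return tf.estimator.SessionRunArgs(self._step_tensor)

    def after_run(self, run_context, run_values):
        print('finished step', run_values.results)

    def end(self, session):
        print('session is about to close')


def input_fn():
    # A minimal (features, labels) Dataset, as described above.
    features = {'x': tf.random.uniform([8, 1])}
    labels = tf.zeros([8, 1])
    return tf.data.Dataset.from_tensors((features, labels)).repeat()


estimator = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('x')])
# max_steps is absolute: repeating this exact call later would be a no-op,
# whereas train(steps=100) would add another 100 steps on top of the checkpoint.
estimator.train(input_fn, hooks=[LoggingLifecycleHook()], max_steps=100)
```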
Following the call chain, _train_model dispatches to _train_model_default when no DistributionStrategy is configured:

```python
def _train_model(self, input_fn, hooks, saving_listeners):
  if self._train_distribution:
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  else:
    return self._train_model_default(input_fn, hooks, saving_listeners)

def _train_model_default(self, input_fn, hooks, saving_listeners):
  """Initiate training with `input_fn`, without `DistributionStrategies`.

  Args:
    input_fn: A function that provides input data for training as minibatches.
    hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
      callbacks inside the training loop.
    saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
      for callbacks that run immediately before or after checkpoint savings.

  Returns:
    Loss from training
  """
  worker_hooks = []
  with tf.Graph().as_default() as g, g.device(self._device_fn):
    tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)
    global_step_tensor = self._create_and_assert_global_step(g)

    # Skip creating a read variable if _create_and_assert_global_step
    # returns None (e.g. tf.contrib.estimator.SavedModelEstimator).
    if global_step_tensor is not None:
      training_util._get_or_create_global_step_read(g)  # pylint: disable=protected-access

    # Essentially pulls the data out of input_fn.
    features, labels, input_hooks = (
        self._get_features_and_labels_from_input_fn(input_fn, ModeKeys.TRAIN))
    worker_hooks.extend(input_hooks)

    # Calls the user-defined model_fn, which defines the configuration and the
    # computation for each mode.
    estimator_spec = self._call_model_fn(features, labels, ModeKeys.TRAIN,
                                         self.config)
    global_step_tensor = tf.compat.v1.train.get_global_step(g)
    # Hands the EstimatorSpec on; the Session ends up being defined inside this call.
    return self._train_with_estimator_spec(estimator_spec, worker_hooks, hooks,
                                           global_step_tensor,
                                           saving_listeners)
```
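For context on what _call_model_fn invokes: the user-supplied model_fn has the signature below and must return an EstimatorSpec for the requested mode. This is only a minimal sketch; my_model_fn is a hypothetical example, not code from the Estimator source.

```python
import tensorflow as tf


def my_model_fn(features, labels, mode, config):
    # A tiny linear model; 'x' matches the feature key used in the earlier input_fn sketch.
    logits = tf.compat.v1.layers.dense(features['x'], 1)
    loss = tf.compat.v1.losses.mean_squared_error(labels, logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.01)
        train_op = optimizer.minimize(
            loss, global_step=tf.compat.v1.train.get_global_step())
        # train_op and loss are exactly the two tensors that
        # _train_with_estimator_spec later feeds to mon_sess.run().
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    # EVAL only needs the loss; a PREDICT branch would return predictions instead.
    return tf.estimator.EstimatorSpec(mode, loss=loss)


estimator = tf.estimator.Estimator(model_fn=my_model_fn)
```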
Drilling further down, we finally reach the code that controls the Session:
```python
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                               global_step_tensor, saving_listeners):
  """Train a model with the given Estimator Spec."""
  if (self._warm_start_settings and
      not tf.train.latest_checkpoint(self._model_dir)):
    tf.compat.v1.logging.info('Warm-starting with WarmStartSettings: %s' %
                              (self._warm_start_settings,))
    tf.compat.v1.train.warm_start(*self._warm_start_settings)
  # Check if the user created a loss summary, and add one if they didn't.
  # We assume here that the summary is called 'loss'. If it is not, we will
  # make another one with the name 'loss' to ensure it shows up in the right
  # graph in TensorBoard.
  if not any([
      x.op.name == 'loss'
      for x in ops.get_collection(ops.GraphKeys.SUMMARIES)
  ]):
    summary.scalar('loss', estimator_spec.loss)
  ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
  worker_hooks.extend(hooks)
  worker_hooks.append(tf.compat.v1.train.NanTensorHook(estimator_spec.loss))
  if self._config.log_step_count_steps is not None:
    worker_hooks.append(
        tf.compat.v1.train.LoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step_tensor
            },
            every_n_iter=self._config.log_step_count_steps))
  worker_hooks.extend(estimator_spec.training_hooks)

  if not (estimator_spec.scaffold.saver or
          tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SAVERS)):
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.SAVERS,
        tf.compat.v1.train.Saver(
            sharded=True,
            max_to_keep=self._config.keep_checkpoint_max,
            keep_checkpoint_every_n_hours=(
                self._config.keep_checkpoint_every_n_hours),
            defer_build=True,
            save_relative_paths=True))

  if (self._config.cluster_spec and type(
      self._train_distribution).__name__ in ('CollectiveAllReduceStrategy',
                                             'CollectiveAllReduceStrategyV1',
                                             'MultiWorkerMirroredStrategy')):
    return self._train_with_estimator_spec_distributed(
        estimator_spec, worker_hooks, saving_listeners)

  chief_hooks = []
  all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
  saver_hooks = [
      h for h in all_hooks
      if isinstance(h, tf.compat.v1.train.CheckpointSaverHook)
  ]
  if (self._config.save_checkpoints_secs or
      self._config.save_checkpoints_steps):
    if not saver_hooks:
      chief_hooks = [
          tf.compat.v1.train.CheckpointSaverHook(
              self._model_dir,
              save_secs=self._config.save_checkpoints_secs,
              save_steps=self._config.save_checkpoints_steps,
              scaffold=estimator_spec.scaffold,
              save_graph_def=self._config.checkpoint_save_graph_def)
      ]
      saver_hooks = [chief_hooks[0]]
  if saving_listeners:
    if not saver_hooks:
      raise ValueError(
          'There should be a CheckpointSaverHook to use saving_listeners. '
          'Please set one of the RunConfig.save_checkpoints_steps or '
          'RunConfig.save_checkpoints_secs.')
    else:
      # It is expected to have one CheckpointSaverHook. If multiple, we pick
      # up the first one to add listener.
      for listener in saving_listeners:
        # pylint: disable=protected-access
        if listener not in saver_hooks[0]._listeners:
          saver_hooks[0]._listeners.append(listener)
        # pylint: disable=protected-access

  # Add summary hooks to worker 0 if we are running with a master, to ensure
  # that summaries are written at correct intervals even with long-running
  # evaluations.
  save_summary_steps = self._config.save_summary_steps
  log_step_count_steps = self._config.log_step_count_steps

  # Check existence of appropriate cluster spec fields, as well as master and
  # worker nodes. As master also performs evaluation, summary writing must
  # occur on a different node. The presence of a worker is also checked to
  # prevent reassigning hooks for single-replica jobs with just a master node.
  if (self._config.cluster_spec and self._config.cluster_spec.jobs and
      (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
      (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
    # Update config values to prevent the default hooks from being created on
    # the master or other workers.
    save_summary_steps = 0
    log_step_count_steps = None

    if (self._config.task_type == run_config.TaskType.WORKER and
        self._config.task_id == 0):
      if (self._config.save_summary_steps and
          self._config.save_summary_steps > 0):
        worker_hooks.append(
            tf.compat.v1.train.SummarySaverHook(
                save_steps=self._config.save_summary_steps,
                output_dir=self._config.model_dir,
                scaffold=estimator_spec.scaffold))

      if (self._config.log_step_count_steps and
          self._config.log_step_count_steps > 0):
        worker_hooks.append(
            tf.compat.v1.train.StepCounterHook(
                every_n_steps=self._config.log_step_count_steps,
                output_dir=self._config.model_dir))

  # Here it is: the session is controlled through MonitoredTrainingSession.
  with training.MonitoredTrainingSession(
      master=self._config.master,
      is_chief=self._config.is_chief,
      checkpoint_dir=self._model_dir,
      scaffold=estimator_spec.scaffold,
      hooks=worker_hooks,
      chief_only_hooks=(tuple(chief_hooks) +
                        tuple(estimator_spec.training_chief_hooks)),
      save_checkpoint_secs=0,  # Saving is handled by a hook.
      save_summaries_steps=save_summary_steps,
      config=self._session_config,
      max_wait_secs=self._config.session_creation_timeout_secs,
      log_step_count_steps=log_step_count_steps,
      save_graph_def=self._config.checkpoint_save_graph_def) as mon_sess:
    loss = None
    any_step_done = False
    while not mon_sess.should_stop():
      _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      any_step_done = True
  if not any_step_done:
    tf.compat.v1.logging.warn('Training with estimator made no steps. '
                              'Perhaps input is empty or misspecified.')
  return loss
```
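So the "Session" the Estimator hides from us is a MonitoredTrainingSession. To see what that wrapper does on its own (create the tf.Session, dispatch hooks, restore and save checkpoints, decide when to stop), here is a stand-alone sketch outside the Estimator; the toy graph and the '/tmp/toy_model' checkpoint directory are made up for illustration.

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# A trivial graph: minimize the squared output of a single weight.
x = tf.random.uniform([8, 1])
w = tf.compat.v1.get_variable('w', [1, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
global_step = tf.compat.v1.train.get_or_create_global_step()
train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(
    loss, global_step=global_step)

# StopAtStepHook plays the role of _convert_train_steps_to_hooks in train().
hooks = [tf.compat.v1.train.StopAtStepHook(last_step=100)]
with tf.compat.v1.train.MonitoredTrainingSession(
        checkpoint_dir='/tmp/toy_model', hooks=hooks) as mon_sess:
    while not mon_sess.should_stop():
        _, loss_val = mon_sess.run([train_op, loss])
```

This is essentially the same loop as above: keep running train_op and loss until should_stop() returns True, with checkpointing and summaries handled by hooks.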
So in the end we have located where the Estimator's Session is actually controlled. A few questions remain, for example how the Estimator switches between modes; I'm still reading that part and suspect the answer is in train_and_evaluate.
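For reference, a minimal, hedged sketch of that entry point, reusing the hypothetical my_model_fn and input_fn from the sketches above:

```python
import tensorflow as tf

# train_and_evaluate wraps train() and evaluate() and picks the proper
# behavior per task type / mode; this is just a usage sketch, not its internals.
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, steps=100)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```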