TensorFlow 中的metric包里方法简介

241次阅读
没有评论

metric包里主要是用来做些衡量指标的,mean、accuracy等指标的计算方法都在这。这些计算的指标值顺便写入到 summary或者在logger hook 里打印都可以。在指标计算的地方有一处,就是返回值会有点让人迷惑,我们 sdk 还改了分布式验证,在此基础上加了 allreduce 操作,比原生的又多了一步。

mean 作为介绍的范例:

@tf_export(v1=['metrics.mean'])
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.
  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.
  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.
  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.
  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_value`.
  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.cast(values, dtypes.float32)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.cast(array_ops.size(values), dtypes.float32)
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.cast(weights, dtypes.float32), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(_, t, c):
      return math_ops.div_no_nan(t, math_ops.maximum(c, 0), name='value')

    mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)
    update_op = math_ops.div_no_nan(
        update_total_op, math_ops.maximum(update_count_op, 0), name='update_op')

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op

上面可以看到返回值是 mean_t 和 update_op

mean_t

mean_t = _aggregate_across_replicas(
        metrics_collections, compute_mean, total, count)

这一步的计算就是使用 total/count 计算的是一个Batch 里面的平均值。 你训练的数据需要 N 个Bacth,这个值计算只是你其中一个Batch的结果。

update_op

update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))

update_op的计算依赖上面的算子,assign_add返回的是一个OP,如果想要得到total的值需要run这个OP,而且计算的结果重新赋给total。到这里可以明白了其实 total 记录的是你训练到现在K 个 Batch的累计结果。

附言

前面也提到分布式验证,这里暂时不贴代码,描述一下。会借助 Allreduce 来获取对应的值,主要在验证的时候使用。

admin
版权声明:本站原创文章,由admin2021-08-26发表,共计3283字。
转载提示:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)