小记 TensorFlow 中 BN 问题

2,282次阅读
没有评论

共计 1477 个字符,预计需要花费 4 分钟才能阅读完成。

日常我用 BN 的机会不是太多,这个可能在图像领域会使用的比较多。使用 BN 会加速你模型训练的速度,这个跟 BP 反向传播梯度更新有很大的关系,这个应该都知道的。

今天说的这个跟TF 的使用上有一定的关系。BN 的操作是针对一批样本进行处理的,在训练的时候会根据训练数据不断的更新自己的参数。具体的API 如下所示:

tf.layers.batch_normalization(
    inputs,
    axis=-1,
    momentum=0.99,
    epsilon=0.001,
    center=True,
    scale=True,
    beta_initializer=tf.zeros_initializer(),
    gamma_initializer=tf.ones_initializer(),
    moving_mean_initializer=tf.zeros_initializer(),
    moving_variance_initializer=tf.ones_initializer(),
    beta_regularizer=None,
    gamma_regularizer=None,
    beta_constraint=None,
    gamma_constraint=None,
    training=False,
    trainable=True,
    name=None,
    reuse=None,
    renorm=False,
    renorm_clipping=None,
    renorm_momentum=0.99,
    fused=None,
    virtual_batch_size=None,
    adjustment=None
)

主要需要注意的是

beta_initializer=tf.zeros_initializer(),
gamma_initializer=tf.ones_initializer(),
moving_mean_initializer=tf.zeros_initializer(),
moving_variance_initializer=tf.ones_initializer(),

这四个参数是在模型训练过程中可训练的。现在日常使用的版本都是1.14版本的TF ,主要使用Estimator 如果不特殊指定 moving_mean和 moving_variance 并不会被保存到 checkpoint 里,导致里在inference的时候出现 很大的偏差,针对这类的问题需要在 Estimator model_fn做出一定的改动。

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)

这个在官方文档里也有记录

Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_optf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op

题外话

在线上 Serving的时候都是对一个用户全部预估候选Item ,这个跟实际线下存在一定的差异,所以感觉不是很适合。对比图像领域不太一样,图像信息固定,学习有一定的规律。

正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2020-12-01发表,共计1477字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码