• 为了保证你在浏览本网站时有着更好的体验,建议使用类似Chrome、Firefox之类的浏览器~~
    • 如果你喜欢本站的内容何不Ctrl+D收藏一下呢,与大家一起分享各种编程知识~
    • 本网站研究机器学习、计算机视觉、模式识别~当然不局限于此,生命在于折腾,何不年轻时多折腾一下

TensorFlow SaveModel API 入门

Tensorflow admin 来源:掘金:水木云石 3周前 (04-29) 161次浏览 0个评论 扫描二维码

最近换了新坑,组里的 tensorflow 版本基本上都赶到了最新的 1.13 版本了,我最近写的代码都是基于 1.13.1 了,其中有一个就是 tensorflow 保存模型这块,使用了 SavedModel。

为什么不使用 checkpoint?

Saver.restore()需要提前建立好计算图,这在理论上是可行的,但是对于模型跨平台来说,成本和效率都存在问题,当模型趋于复杂,序列模型、深度卷积、复杂全连接以及种种超参数以及优化技术都需要两端完全匹配,就目前来看是得不偿失的。


这两天搜索了不少关于 Tensorflow 模型保存与加载的资料,发现很多资料都是关于 checkpoints 模型格式的,而最新的 SavedModel 模型格式则资料较少,为此总结一下 TensorFlow 如何保存 SavedModel 模型,并加载之。

为什么要采用 SavedModel 格式呢?其主要优点是 SaveModel 与语言无关,比如可以使用 python 语言训练模型,然后在 Java 中非常方便的加载模型。当然这也不是说 checkpoints 模型格式做不到,只是在跨语言时比较麻烦。另外如果使用 Tensorflow Serving server 来部署模型,必须选择 SavedModel 格式。

SavedModel 包含啥?

一个比较完整的 SavedModel 模型包含以下内容:

assets/
assets.extra/
variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

saved_model.pb 是 MetaGraphDef,它包含图形结构。variables 文件夹保存训练所习得的权重。assets 文件夹可以添加可能需要的外部文件,assets.extra 是一个库可以添加其特定 assets 的地方。

MetaGraph 是一个数据流图,加上其相关的变量、assets 和签名。MetaGraphDef 是 MetaGraph 的 Protocol Buffer 表示。

assets 和 assets.extra 是可选的,比如本文示例代码保存的模型只包含以下的内容:

variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

保存

为了简单起见,我们使用一个非常简单的手写识别代码作为示例,代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

这段代码很简单,一个简单的梯度递减回归模型。要保存该模型,我们还需要对代码作一点小小的改动。

添加命名

在输入和输出 Ops 中添加名称,这样我们在加载时可以方便的按名称引用操作。将上述的 x 赋值语句修改为:

x = tf.placeholder(tf.float32, [None, 784], name="myInput")

当然你也可以不给名称,系统会默认给一个名称,比如上面的 x 系统会给一个”Placeholder”,当我们需要引用多个 op 的时候,给每个 op 一个命名,确实方便给我们后面使用。

你也可以使用 tf.identity 给 tensor 命名,比如在上述代码上添加一行:

tf.identity(y, name="myOutput")

给输出也命一个名。

保存到文件

最简单的保存方法是使用 tf.saved_model.simple_save 函数,代码如下:

tf.saved_model.simple_save(sess,
            "./model",
            inputs={"myInput": x},
            outputs={"myOutput": y})

这段代码将模型保存在**./model**目录。

当然你也可以采用比较复杂的写法,目前使用的也是以下定义的方式:

关于输入与输出可以使用

tf.saved_model.utils.build_tensor_info

来构建,就是将张量转为 protobuff 结构的快捷方法,也就是说下面的输入 x 以及输出 y 都是经过该函数处理之后的结果。

builder = tf.saved_model.builder.SavedModelBuilder("./model")

signature = predict_signature_def(inputs={'myInput': x},
                                  outputs={'myOutput': y})
builder.add_meta_graph_and_variables(sess=sess,
                                     tags=[tag_constants.SERVING],
                                     signature_def_map={'predict': signature})
builder.save()

signature对象,这个对象包含了计算图中输入与输出张量的键值对信息,键即是张量名,值即是 protobuff 结构的张量。

看起来新的代码差别不大,区别就在于可以自己定义 tag,在签名的定义上更加灵活。这里说说 tag 的用途吧。

一个模型可以包含不同的 MetaGraphDef,什么时候需要多个 MetaGraphDef 呢?也许你想保存图形的 CPU 版本和 GPU 版本,或者你想区分训练和发布版本。这个时候 tag 就可以用来区分不同的 MetaGraphDef,加载的时候能够根据 tag 来加载模型的不同计算图。

在 simple_save 方法中,系统会给一个默认的 tag: “serve”,也可以用 tag_constants.SERVING 这个常量。

signature_def_map={'predict': signature})  这个参数非常的重要,在你这里指定了模型的输入与输出。

加载

对不同语言而言,加载过程有些类似,这里还是以 python 为例:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, ["serve"], "./model")
  graph = tf.get_default_graph()

  input = np.expand_dims(mnist.test.images[0], 0)
  x = sess.graph.get_tensor_by_name('myInput:0')
  y = sess.graph.get_tensor_by_name('myOutput:0')
  batch_xs, batch_ys = mnist.test.next_batch(1)
  scores = sess.run(y,
           feed_dict={x: batch_xs})
  print("predict: %d, actual: %d" % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))

需要注意,load 函数中第二个参数是 tag,需要和保存模型时的参数一致,第三个参数是模型保存的文件夹。

调用 load 函数后,不仅加载了计算图,还加载了训练中习得的变量值,有了这两者,我们就可以调用其进行推断新给的测试数据。

小结

将过程捋顺了之后,你会发觉保存和加载 SavedModel 其实很简单。但在摸索过程中,也走了不少的弯路,主要原因是现在搜索到的大部分资料还是用 tf.train.Saver()来保存模型,还有的是用 tf.gfile.FastGFile 来序列化模型图。


Deeplearn, 版权所有丨如未注明 , 均为原创丨本网站采用BY-NC-SA协议进行授权 , 转载请注明TensorFlow SaveModel API 入门
喜欢 (0)
admin
关于作者:

您必须 登录 才能发表评论!