TensorFlow SaveModel API 入门

7,449次阅读
没有评论

共计 3782 个字符,预计需要花费 10 分钟才能阅读完成。

最近换了新坑,组里的tensorflow版本基本上都赶到了最新的1.13版本了,我最近写的代码都是基于1.13.1了,其中有一个就是tensorflow 保存模型这块,使用了SavedModel。

为什么不使用checkpoint?

Saver.restore()需要提前建立好计算图,这在理论上是可行的,但是对于模型跨平台来说,成本和效率都存在问题,当模型趋于复杂,序列模型、深度卷积、复杂全连接以及种种超参数以及优化技术都需要两端完全匹配,就目前来看是得不偿失的。


这两天搜索了不少关于Tensorflow模型保存与加载的资料,发现很多资料都是关于checkpoints模型格式的,而最新的SavedModel模型格式则资料较少,为此总结一下TensorFlow如何保存SavedModel模型,并加载之。

为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Serving server来部署模型,必须选择SavedModel格式。

SavedModel包含啥?

一个比较完整的SavedModel模型包含以下内容:

assets/
assets.extra/
variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

saved_model.pb是MetaGraphDef,它包含图形结构。variables文件夹保存训练所习得的权重。assets文件夹可以添加可能需要的外部文件,assets.extra是一个库可以添加其特定assets的地方。

MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。

assets和assets.extra是可选的,比如本文示例代码保存的模型只包含以下的内容:

variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

保存

为了简单起见,我们使用一个非常简单的手写识别代码作为示例,代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), 1))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
tf.global_variables_initializer().run()

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_step.run({x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

这段代码很简单,一个简单的梯度递减回归模型。要保存该模型,我们还需要对代码作一点小小的改动。

添加命名

在输入和输出Ops中添加名称,这样我们在加载时可以方便的按名称引用操作。将上述的x赋值语句修改为:

x = tf.placeholder(tf.float32, [None, 784], name="myInput")

当然你也可以不给名称,系统会默认给一个名称,比如上面的x系统会给一个”Placeholder”,当我们需要引用多个op的时候,给每个op一个命名,确实方便给我们后面使用。

你也可以使用tf.identity给tensor命名,比如在上述代码上添加一行:

tf.identity(y, name="myOutput")

给输出也命一个名。

保存到文件

最简单的保存方法是使用tf.saved_model.simple_save函数,代码如下:

tf.saved_model.simple_save(sess,
            "./model",
            inputs={"myInput": x},
            outputs={"myOutput": y})

这段代码将模型保存在**./model**目录。

当然你也可以采用比较复杂的写法,目前使用的也是以下定义的方式:

关于输入与输出可以使用

tf.saved_model.utils.build_tensor_info

来构建,就是将张量转为protobuff结构的快捷方法,也就是说下面的输入x 以及输出 y都是经过该函数处理之后的结果。

builder = tf.saved_model.builder.SavedModelBuilder("./model")

signature = predict_signature_def(inputs={'myInput': x},
                                  outputs={'myOutput': y})
builder.add_meta_graph_and_variables(sess=sess,
                                     tags=[tag_constants.SERVING],
                                     signature_def_map={'predict': signature})
builder.save()

signature对象,这个对象包含了计算图中输入与输出张量的键值对信息,键即是张量名,值即是protobuff结构的张量。

看起来新的代码差别不大,区别就在于可以自己定义tag,在签名的定义上更加灵活。这里说说tag的用途吧。

一个模型可以包含不同的MetaGraphDef,什么时候需要多个MetaGraphDef呢?也许你想保存图形的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。

在simple_save方法中,系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。

signature_def_map={'predict': signature})  这个参数非常的重要,在你这里指定了模型的输入与输出。

加载

对不同语言而言,加载过程有些类似,这里还是以python为例:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, ["serve"], "./model")
  graph = tf.get_default_graph()

  input = np.expand_dims(mnist.test.images[0], 0)
  x = sess.graph.get_tensor_by_name('myInput:0')
  y = sess.graph.get_tensor_by_name('myOutput:0')
  batch_xs, batch_ys = mnist.test.next_batch(1)
  scores = sess.run(y,
           feed_dict={x: batch_xs})
  print("predict: %d, actual: %d" % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))

需要注意,load函数中第二个参数是tag,需要和保存模型时的参数一致,第三个参数是模型保存的文件夹。

调用load函数后,不仅加载了计算图,还加载了训练中习得的变量值,有了这两者,我们就可以调用其进行推断新给的测试数据。

小结

将过程捋顺了之后,你会发觉保存和加载SavedModel其实很简单。但在摸索过程中,也走了不少的弯路,主要原因是现在搜索到的大部分资料还是用tf.train.Saver()来保存模型,还有的是用tf.gfile.FastGFile来序列化模型图。

  • 1 2
正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2019-04-29发表,共计3782字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码