TensorFlow SaveModel API 入门

3,782次阅读
没有评论

最近换了新坑,组里的tensorflow版本基本上都赶到了最新的1.13版本了,我最近写的代码都是基于1.13.1了,其中有一个就是tensorflow 保存模型这块,使用了SavedModel。

为什么不使用checkpoint?

Saver.restore()需要提前建立好计算图,这在理论上是可行的,但是对于模型跨平台来说,成本和效率都存在问题,当模型趋于复杂,序列模型、深度卷积、复杂全连接以及种种超参数以及优化技术都需要两端完全匹配,就目前来看是得不偿失的。


这两天搜索了不少关于Tensorflow模型保存与加载的资料,发现很多资料都是关于checkpoints模型格式的,而最新的SavedModel模型格式则资料较少,为此总结一下TensorFlow如何保存SavedModel模型,并加载之。

为什么要采用SavedModel格式呢?其主要优点是SaveModel与语言无关,比如可以使用python语言训练模型,然后在Java中非常方便的加载模型。当然这也不是说checkpoints模型格式做不到,只是在跨语言时比较麻烦。另外如果使用Tensorflow Serving server来部署模型,必须选择SavedModel格式。

SavedModel包含啥?

一个比较完整的SavedModel模型包含以下内容:

assets/
assets.extra/
variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

saved_model.pb是MetaGraphDef,它包含图形结构。variables文件夹保存训练所习得的权重。assets文件夹可以添加可能需要的外部文件,assets.extra是一个库可以添加其特定assets的地方。

MetaGraph是一个数据流图,加上其相关的变量、assets和签名。MetaGraphDef是MetaGraph的Protocol Buffer表示。

assets和assets.extra是可选的,比如本文示例代码保存的模型只包含以下的内容:

variables/
    variables.data-*****-of-*****
    variables.index
saved_model.pb

保存

为了简单起见,我们使用一个非常简单的手写识别代码作为示例,代码如下:

<span class="hljs-keyword">from</span> tensorflow.examples.tutorials.mnist <span class="hljs-keyword">import</span> input_data
<span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf

mnist = input_data.read_data_sets(<span class="hljs-string">"MNIST_data/"</span>, one_hot=<span class="hljs-keyword">True</span>)

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [<span class="hljs-keyword">None</span>, <span class="hljs-number">784</span>])
W = tf.Variable(tf.zeros([<span class="hljs-number">784</span>, <span class="hljs-number">10</span>]))
b = tf.Variable(tf.zeros([<span class="hljs-number">10</span>]))

y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder(tf.float32, [<span class="hljs-keyword">None</span>, <span class="hljs-number">10</span>])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), <span class="hljs-number">1</span>))

train_step = tf.train.GradientDescentOptimizer(<span class="hljs-number">0.5</span>).minimize(cross_entropy)
tf.global_variables_initializer().run()

<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(<span class="hljs-number">1000</span>):
    batch_xs, batch_ys = mnist.train.next_batch(<span class="hljs-number">100</span>)
    train_step.run({x: batch_xs, y_: batch_ys})

correct_prediction = tf.equal(tf.argmax(y, <span class="hljs-number">1</span>), tf.argmax(y_, <span class="hljs-number">1</span>))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

这段代码很简单,一个简单的梯度递减回归模型。要保存该模型,我们还需要对代码作一点小小的改动。

添加命名

在输入和输出Ops中添加名称,这样我们在加载时可以方便的按名称引用操作。将上述的x赋值语句修改为:

x = tf.placeholder(tf.float32, [None, 784], name=<span class="hljs-string">"myInput"</span>)

当然你也可以不给名称,系统会默认给一个名称,比如上面的x系统会给一个”Placeholder”,当我们需要引用多个op的时候,给每个op一个命名,确实方便给我们后面使用。

你也可以使用tf.identity给tensor命名,比如在上述代码上添加一行:

tf.identity(y, name=<span class="hljs-string">"myOutput"</span>)

给输出也命一个名。

保存到文件

最简单的保存方法是使用tf.saved_model.simple_save函数,代码如下:

tf.saved_model.simple_save(sess,
            <span class="hljs-string">"./model"</span>,
            inputs={<span class="hljs-string">"myInput"</span>: x},
            outputs={<span class="hljs-string">"myOutput"</span>: y})

这段代码将模型保存在**./model**目录。

当然你也可以采用比较复杂的写法,目前使用的也是以下定义的方式:

关于输入与输出可以使用

tf.saved_model.utils.build_tensor_info

来构建,就是将张量转为protobuff结构的快捷方法,也就是说下面的输入x 以及输出 y都是经过该函数处理之后的结果。

builder = tf.saved_model.builder.SavedModelBuilder(<span class="hljs-string">"./model"</span>)

signature = predict_signature_def(inputs={<span class="hljs-string">'myInput'</span>: x},
                                  outputs={<span class="hljs-string">'myOutput'</span>: y})
builder.add_meta_graph_and_variables(sess=sess,
                                     tags=[tag_constants.SERVING],
                                     signature_def_map={<span class="hljs-string">'predict'</span>: signature})
builder.save()

signature对象,这个对象包含了计算图中输入与输出张量的键值对信息,键即是张量名,值即是protobuff结构的张量。

看起来新的代码差别不大,区别就在于可以自己定义tag,在签名的定义上更加灵活。这里说说tag的用途吧。

一个模型可以包含不同的MetaGraphDef,什么时候需要多个MetaGraphDef呢?也许你想保存图形的CPU版本和GPU版本,或者你想区分训练和发布版本。这个时候tag就可以用来区分不同的MetaGraphDef,加载的时候能够根据tag来加载模型的不同计算图。

在simple_save方法中,系统会给一个默认的tag: “serve”,也可以用tag_constants.SERVING这个常量。

signature_def_map={'predict': signature})  这个参数非常的重要,在你这里指定了模型的输入与输出。

加载

对不同语言而言,加载过程有些类似,这里还是以python为例:

mnist = input_data.read_data_sets(<span class="hljs-string">"MNIST_data/"</span>, one_hot=True)

with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, [<span class="hljs-string">"serve"</span>], <span class="hljs-string">"./model"</span>)
  graph = tf.get_default_graph()

  input = np.expand_dims(mnist.test.images[0], 0)
  x = sess.graph.get_tensor_by_name(<span class="hljs-string">'myInput:0'</span>)
  y = sess.graph.get_tensor_by_name(<span class="hljs-string">'myOutput:0'</span>)
  batch_xs, batch_ys = mnist.test.next_batch(1)
  scores = sess.run(y,
           feed_dict={x: batch_xs})
  <span class="hljs-built_in">print</span>(<span class="hljs-string">"predict: %d, actual: %d"</span> % (np.argmax(scores, 1), np.argmax(batch_ys, 1)))

需要注意,load函数中第二个参数是tag,需要和保存模型时的参数一致,第三个参数是模型保存的文件夹。

调用load函数后,不仅加载了计算图,还加载了训练中习得的变量值,有了这两者,我们就可以调用其进行推断新给的测试数据。

小结

将过程捋顺了之后,你会发觉保存和加载SavedModel其实很简单。但在摸索过程中,也走了不少的弯路,主要原因是现在搜索到的大部分资料还是用tf.train.Saver()来保存模型,还有的是用tf.gfile.FastGFile来序列化模型图。

  • 1 2
admin
版权声明:本站原创文章,由admin2019-04-29发表,共计3782字。
转载提示:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)