什么是checkpoint?
检查点checkpoint中存储着模型model所使用的的所有的 tf.Variable 对象以及模型结构的定义
checkpoint的一般格式如下:
(1)meta文件
.meta文件:一个协议缓冲,保存tensorflow中完整的graph、variables、operation、collection;这是我们恢复模型结构的参照;
meta文件保存的是图结构,通俗地讲就是神经网络的网络结构。我们可以使用下面的代码只在第一次保存meta文件。
<code>saver.save(sess, 'my_model.ckpt', global_step=step, write_meta_graph=False) </code>
在后面恢复整个graph的结构的时候,并且还可以使用
<code>tf.train.import_meta_graph(‘xxxxxx.meta’) </code>
能够导入图结构。
(2)data文件
keypoint_model.ckpt-9.data-00000-of-00001:数据文件,保存的是网络的权值,偏置,操作等等。
(3)index文件
keypoint_model.ckpt-9.index 是一个不可变得字符串字典,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据,所谓的元数据就是描述这个Variable 的一些信息的数据。 “数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。
Note: 以前的版本中tensorflow的model只保存一个文件中。
(4)checkpoint文件——文本文件
checkpoint是一个文本文件,记录了训练过程中在所有中间节点上保存的模型的名称,首行记录的是最后(最近)一次保存的模型名称。checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;比如我上面的模型保存了最后的5份checkpoint,这里打开checkpoint查看得到如下内容:
<code>model_checkpoint_path: "keypoint_model.ckpt-9" # 最新的那一份 all_model_checkpoint_paths: "keypoint_model.ckpt-5" all_model_checkpoint_paths: "keypoint_model.ckpt-6" all_model_checkpoint_paths: "keypoint_model.ckpt-7" all_model_checkpoint_paths: "keypoint_model.ckpt-8" all_model_checkpoint_paths: "keypoint_model.ckpt-9 </code>
实例
<code class="language-python">import tensorflow as tf import numpy as np # 1.准备数据: x = tf.placeholder(tf.float32, shape=[None, 1]) y = 4 * x + 4 # 2.构造一个线性模型 w = tf.Variable(tf.random_normal([1], -1, 1)) #创建新对象,当检测到命名冲突时,系统会自己处理 b = tf.Variable(tf.zeros([1])) y_predict = w * x + b # 3.求解模型 # 设置损失函数:误差的均方差 loss = tf.reduce_mean(tf.square(y - y_predict)) # 选择梯度下降的方法 optimizer = tf.train.GradientDescentOptimizer(0.5) # 迭代的目标:最小化损失函数 train = optimizer.minimize(loss) #参数定义声明 isTrain = True train_steps = 100 checkpoint_steps = 50 checkpoint_dir = 'F:\\\\vivocode\\\\tftestmodel\\\\' saver = tf.train.Saver() # defaults to saving all variables - in this case w and b x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1)) ############################################################ # 以下是用 tf 来解决上面的任务 # 1.初始化变量:tf 的必备步骤,主要声明了变量,就必须初始化才能用 # init = tf.global_variables_initializer() # 设置tensorflow对GPU的使用按需分配 #config = tf.ConfigProto() #config.gpu_options.allow_growth = True # 2.启动图 (graph) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) #判断当前工作状态 if isTrain: #isTrain:True表示训练;False:表示测试 # 3.迭代,反复执行上面的最小化损失函数这一操作(train op),拟合平面 for i in range(train_steps): #train_steps表示训练的次数,例子中使用1006666666 sess.run(train, feed_dict={x: x_data}) if (i + 1) % checkpoint_steps == 0: #表示训练多少次保存一下checkpoints,例子中使用50 print ('step: {} train_acc: {} loss: {}'.format(i, sess.run(w), sess.run(b))) saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) #表示checkpoints文件的保存路径,例子中使用当前路径 else: #如果isTrain=False,则进行测试 ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) #恢复变量 else: pass print(sess.run(w),sess.run(b)) </code>
生成的checkpoint 如下所示:
load 数据
如果你导入meta数据是不需要在定义图就可以做些运算
<code class="language-python">import tensorflow as tf import numpy as np checkpoint_dir = 'F:\\\\vivocode\\\\tftestmodel\\\\' # 2.启动图 (graph) # saver = tf.train.Saver() with tf.Session() as sess: model_path = tf.train.latest_checkpoint(checkpoint_dir) # 获取最新的模型,注意这里的是文件夹哦 saver=tf.train.import_meta_graph(model_path+'.meta') saver.restore(sess,model_path) graph = tf.get_default_graph() X = graph.get_tensor_by_name('Variable:0') print(sess.run(X)) [3.985119] </code>