1,646次阅读
没有评论

共计 2761 个字符,预计需要花费 7 分钟才能阅读完成。

图是静态的,无论是加减乘除,只是定义了各种计算关系,不会有实际的任何运算

图的组成

  1. 输入节点
  2. 模型参数
  3. OP

默认计算图

在TensorFlow中会自动维护一个默认的一个计算图,所以我们能够直接定义的tensor或者运算都会被转换为计算图上一个节点。

v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
with tf.Session() as sess:
    # 判断v1所在的graph是否是默认的graph
    print(v1.graph is tf.get_default_graph())
    print(sess.run(add))
    # 输出 True
    # 输出 [[3. 3.]]

我们可以通过tf.get_default_graph()来获取当前节点所在的计算图。我们通过判断v1tensor所在的计算图和默认的计算图进行比较,发现v1的值处于默认的计算图上,由此也验证了:TensorFlow会自动维护一个默认的计算图,并将我们的节点添加到默认的计算图上。

图

我们可以看到默认的计算图上有三个节点,分别是v1v1节点,它们共同组成了add节点。

创建Graph

我们可以通过tf.Graph()新增计算图,并通过as_default()将变量和计算添加在当前的计算图中,最后通过Session的graph=计算图来计算指定的计算图。

# 新增计算图
new_graph = tf.Graph()
with new_graph.as_default():
    # 在新增的计算图中进行计算
    v1 = tf.constant(value=3, name='v1', shape=(1, 2), dtype=tf.float32)
    v2 = tf.constant(value=4, name='v2', shape=(1, 2), dtype=tf.float32)
    add = v1 + v2
#  通过graph=new_graph指定Session所在的计算图
with tf.Session(graph=new_graph) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(add))
# 在默认计算图中进行计算
v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
# 通过graph=tf.get_default_graph()指定Session所在默认的计算图
with tf.Session(graph=tf.get_default_graph()) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(add))

# 输出:[[7. 7.]]
# 输出:[[3. 3.]]

带有PlaceHolder的计算图

import  tensorflow as tf

a=tf.placeholder(dtype=tf.float32,shape=[1])
b=tf.placeholder(dtype=tf.float32,shape=[1])
c=a+b
init=tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(c,feed_dict={a:[2.1],b:[3.2]}))

[5.3]

多个图之间互不相干

import tensorflow as tf

g1=tf.Graph()
with g1.as_default():
    v=tf.get_variable("v",[1],initializer=tf.zeros_initializer(dtype=tf.float32))

g2=tf.Graph()
with g2.as_default():
    v=tf.get_variable("v",[1],initializer=tf.ones_initializer(dtype=tf.float32))

with tf.Session(graph=g1) as sess:
    tf.initialize_all_variables().run()
    with tf.variable_scope("",reuse=True):  # 当reuse=True时,tf.get_variable只能获取指定命名空间内的已创建的变量
         print(sess.run(tf.get_variable("v")))

with tf.Session(graph=g2) as sess:
    tf.initialize_all_variables().run()
    with tf.variable_scope("",reuse=True):  # 当reuse=True时,tf.get_variable只能获取指定命名空间内的已创建的变量
         print(sess.run(tf.get_variable("v")))

#输出:[0.]      [1.]

跟图相关的一些操作

1、根据 tensor name 来获取对应的tensor

对应的方法 get_tensor_by_name

import  tensorflow as tf

a=tf.placeholder(dtype=tf.float32,shape=[1],name='v1')
b=tf.placeholder(dtype=tf.float32,shape=[1],name='v2')
c=a+b
d=tf.add(a,b,name='add')
init=tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(c,feed_dict={a:[2.1],b:[3.2]}))

    test1=tf.get_default_graph().get_tensor_by_name('add:0')

    print(sess.run(test1,feed_dict={a:[1.0],b:[2.0]}))

[5.3]
[3.]

2、获取 operation 信息

对应的方法 get_operation_by_name

Q : with new_graph.as_default(): 在这里面运行tf.get_default_graph()获取的是什么图?

正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2021-01-05发表,共计2761字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码