TensorFlow Graph 操作

港控/mmm° 2023-03-01 09:46 36阅读 0赞

转载请注明作者和出处: http://blog.csdn.net/john\_bh/

文章目录

    • 1.创建图
    • 2.获取张量
    • 3.获取节点操作

1.创建图

在tensorflow中,一个程序默认是建立一个图的,除了系统自动建立图以外,还可以手动建立图,并做一些其他的操作。

tf.Graph() 表示实例化了一个类,一个用于 tensorflow 计算和表示用的数据流图,通俗来讲就是:在代码中添加的操作(画中的结点)和数据(画中的线条)都是画在纸上的“画”,而图就是呈现这些画的纸,你可以利用很多线程生成很多张图,但是默认图就只有一张
tf.Graph() 函数非常重要,注意体现在两个方面:

  1. 通过tensorboard用图形化界面展示出来流程结构;
  2. 整合一段代码为一个整体存在于一个图中。

tf.Graph().as_default() 表示将这个类实例,也就是新生成的图作为整个 tensorflow 运行环境的默认图,如果只有一个主线程不写也没有关系,tensorflow 里面已经存好了一张默认图,可以使用 tf.get_default_graph() 来调用(显示这张默认纸),当你有多个线程就可以创造多个tf.Graph(),就是你可以有一个画图本,有很多张图纸,这时候就会有一个默认图的概念了。

下面使用tf.Graph函数建立图,使用tf.get_default_graph函数来获取图,使用reset_default_graph对图进行重置。

  1. import tensorflow as tf
  2. import numpy as np
  3. c = tf.constant(1.5)
  4. g = tf.Graph()
  5. with g.as_default():
  6. c1 = tf.constant(2.0)
  7. print(c1.graph)
  8. print(g)
  9. print(c.graph)
  10. g2 = tf.get_default_graph()
  11. print(g2)
  12. tf.reset_default_graph()
  13. g3 = tf.get_default_graph()
  14. print(g3)

根据上述的运行结果,c是在刚开始的默认图中建立的,所以打印的结果就是13376A1FE10,和g2获取的默认图的值是一样的,然后使用tf.Graph建立了一个新的图,并添加了变量c1,最后又对图进行了重置,替代了原来的默认图。

在使用reset_default_graph()函数的时候,要保证当前图中资源都已经全部进行了释放,否则将会报错。

2.获取张量

我们可以在图中通过名字得到其对应的元素,比如获取图中的变量和OP等元素。

  1. import tensorflow as tf
  2. import numpy as np
  3. g = tf.Graph()
  4. with g.as_default():
  5. c1 = tf.constant(2.5, name='c1_constant')
  6. c2 = tf.Variable(1.5, dtype=tf.float32, name='c2_constant')
  7. add = tf.multiply(c1, c2, name='op_add')
  8. c_1 = g.get_tensor_by_name(name='c1_constant:0')
  9. c_2 = g.get_tensor_by_name(name='c2_constant:0')
  10. c_3 = g.get_tensor_by_name(name='op_add:0')
  11. print(c_1)
  12. print(c_2)
  13. print(c_3)

在进行测试时,我们为元素添加了变量名,在设置变量名的时候,设置好的名字会自动添加后面的:0字符。一般我们可以将名字打印出来,在将打印好的名字进行回填。

3.获取节点操作

获取节点操作OP的方法和获取张量的方法非常类似,使用get_operation_by_name.下面是运行实例:

  1. import tensorflow as tf
  2. import numpy as np
  3. a = tf.constant([[1.0, 2.0]])
  4. b = tf.constant([[1.0], [3.0]])
  5. tensor_1 = tf.matmul(a, b, name='matmul_1')
  6. with tf.Session() as sess:
  7. sess.run(tf.global_variables_initializer())
  8. t1 = tf.get_default_graph().get_operation_by_name(name='matmul_1')
  9. t2 = tf.get_default_graph().get_tensor_by_name(name='matmul_1:0')
  10. print(t1)
  11. print('t1: ', sess.run(t1))
  12. print('t2: ', sess.run(t2))

在上述的代码中,定义了一个OP操作,命名为matmul_1,在运行时我们将op打印出来,在使用名字后面加上:0我们就能得到OP运算的结果的tensor,注意这两者的区别。

我们还可以通过get_opreations函数获取图中的所有信息。此外,我们还可以使用tf.Grapg.as_graph_element函数将传入的对象返回为张量或者op。该函数具有验证和转换功能。

发表评论

表情:
评论列表 (有 0 条评论,36人围观)

还没有评论,来说两句吧...

相关阅读