Tensorflow详解(一)

本文深入解析TensorFlow计算图的基本概念,包括节点、边和张量的角色与功能,以及如何创建和管理多个计算图。同时,介绍了计算图中的集合管理机制,如变量、训练变量、日志变量等,并演示了如何将个体加入集合及获取集合内容。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

初识计算图

 

计算图——tensorflow程序的计算过程可以表示为一个计算图  

    节点(Node)——计算图中的每一个运算操作,可以有任意个输入和输出

    边(Edge)——两个节点直接的数据流动,或起依赖控制的作用

    张量(Tensor)——边中流动(Flow)的数据

 

图的创建

 1.系统自动维护一个默认的计算图

 

 2.tf.Graph()创建一个计算图(如果需要多个图来完成工作时)

   不同计算图上的张量和节点(运算)都不会共享,也就是说,我们不能在某一计算图上调用其他计算图的成员

import tensorflow as tf

# 使用Graph()函数创建一个计算图
g1 = tf.Graph()
with g1.as_default():  # 将定义的计算图使用as_default()函数设为默认

    # 创建计算图中的变量并设置初始值为
    a = tf.get_variable("a", [2], initializer=tf.ones_initializer())
    b = tf.get_variable("b", [2], initializer=tf.zeros_initializer())

# 使用Graph()函数创建另一个计算图
g2 = tf.Graph()
with g2.as_default():
    a = tf.get_variable("a", [2], initializer=tf.zeros_initializer())
    b = tf.get_variable("b", [2], initializer=tf.ones_initializer())

with tf.Session(graph=g1) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("", reuse=True):
        print(sess.run(tf.get_variable("a")))
        print(sess.run(tf.get_variable("b")))
        # 打印[1. 1.]
        #    [0. 0.]

with tf.Session(graph=g2) as sess:
    tf.global_variables_initializer().run()
    with tf.variable_scope("", reuse=True):
        print(sess.run(tf.get_variable("a")))
        print(sess.run(tf.get_variable("b")))
        # 打印[0. 0.]
        #    [1. 1.]

3.集合(collection)

对于每一个计算图,Tensorflow通过5个默认集合管理其中不同类型的个体(所谓个体可以是张量,变量或队列等) 

TensorFlow维护的默认集合以及内容
tf.GraphKeys.VARIABLE所有变量

tf.GraphKeys.TRAINABLE_VARIABLES

可学习(训练)的变量(一般为神经网络中的参数)

tf.GraphKeys.SUMMARIES

日志生成相关的变量

tf.GraphKeys.QUEUE_RUNNERS

处理输入的QUEUE_RUNNERS

tf.GraphKeys.MOVING_AVERAGE_VARIABLES

所有计算了滑动平均值的变量

3.1 将个体加入一个或多个集合中;自定义一个name集合

      tf.add_to_collection(name, value)

3.2 获取一个集合中的所有个体

     tf.get_collection(key, scope=None)

                 key:用来获取一个名称是‘key’的集合中的所有元素,返回的是一个列表,列表的顺序是按照变量放入集合中的先后;

                 scope:参数可选,表示的是名称空间(名称域),如果指定,就返回名称域中所有放入‘key’的变量的列表,不指定则                              返回所有变量

3.3列表内容相加

      tf.add_n()

import tensorflow as tf
import numpy as np

with tf.variable_scope('one'):
    v1 = tf.get_variable(name='v1', shape=[1], initializer=tf.constant_initializer(0))
    tf.add_to_collection('loss', v1)
with tf.variable_scope('two'):
    v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(2))
    tf.add_to_collection('loss', v2)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    #获取集合loss中所有对象
    print (tf.get_collection('loss'))
    # 获取集合loss中属于one变量空间的对象
    print(tf.get_collection('loss',scope='one'))
    print(sess.run(tf.add_n(tf.get_collection('loss'))))

输出结果

[<tf.Variable 'one/v1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'two/v2:0' shape=(1,) dtype=float32_ref>]
[<tf.Variable 'one/v1:0' shape=(1,) dtype=float32_ref>]
[2.]

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值