zoukankan      html  css  js  c++  java
  • tensorflow学习笔记(二)常量、变量、占位符、会话

    常量、变量、占位符、会话是tensorflow编程的基础也是最常用到的东西,tensorflow中定义的变量、常量都是tensor(张量)类型。

    常量tf.constant()

    tensorflow中定义的变量、常量都是tensor(张量)类型常用是在运行过程中不会改变的量,如作线性回归Y = w*X + b ,知道一系列(X, Y) ,通过梯度下降找w和b,X和Y的值在程序运行时就不会去改变,只不断改变w和b去减小与真实值的误差,所以常量常用来表示输入输出。

    声明一个标量常量:

    t_1 = tf.constant(5)

    声明一个向量常量:

    t_2 = tf.constant([2, 3, 5])

    变量tf.Variable()

    作线性回归时要不断调整 w、b 做拟合,w和b就要声明为变量,所以变量常用来表示模型中的参数。

    声明一个M行N列,全为零的变量:

    b_1 = tf.Variable(tf.zeros([M, N], tf.float32))   # tf.zeros 创建全0张量,数字的类型是float32 ,然后用tf.Variable()将其变成变量

    声明一个呈正态分布的 均值是2(默认=0.0)标准差是4(默认是1.0)的2行3列张量:

    w_1 = tf.Variable(tf.random_nomarl([2, 3], mean=2.0, stddev=4, seed=2)

    会话Session()

    如同main()函数一样,会话是tensorflow程序的入口,tf 程序一般是先定义节点和节点的关系(运算),然后在会话中根据定义的运算自动算出结果 。

    import tensorflow as tf
    a = tf.constant([1, 2])
    b = tf.constant([2, 3])  # 定义常量a,b是1X2的张量
    c = tf.add(a, b)    # 定义 c=a+b
    d = tf.scalar_mul(tf.constant(2),c)  # 定义 d=2*c
    
    with tf.Session() as sess:
        print(sess.run(d))   # 根据定义的运算(图)计算 d 的值,并打印

    第四行第五行定义c、d的运算,但是在那里并没有直接得出c和d的结果,在会话中sess.run(d) 计算d ,它会自动根据前面定义好的运算计算出d的结果,而不需显示的先计算c  sess.run(c) 再计算d。如果在会话总只sess.run(c),程序就不会计算d了。

    占位符tf.placeholder()

    如名字一样,占位符就是先给变量占一个位,可以先不给变量赋具体值,先给变量一个位置,在会话运行时给变量传入具体的值。即占位符用于将数据提供给计算图。

    tf.placeholder(dtype,shape=None,name=None)  

    dtype是变量的数据类型,shape是变量的形状(几行几列),name是变量的名称。

    import tensorflow as tf
    import numpy as np
    a = np.array([1, 2])
    b = np.array([2, 3])   # 创建a,b两个1x2的ndarry变量
    X = tf.placeholder(tf.int32)
    Y = tf.placeholder(tf.int32)  # 定义两个占位符,张量形状可以不写,传入值的时候会自动判断
    
    c = tf.add(X, Y)
    d = tf.scalar_mul(tf.constant(2),c)
    
    with tf.Session() as sess:
        re = sess.run(d, feed_dict={X:a, Y:b})   # 在图计算时提供具体值
        print(re)
    

     会话在用feed_dict = { } 传入值的时候传入的不能是tf.constant()这种类型,必须是数组、np.ndarry等具体数值的类型。

  • 相关阅读:
    枚举工具类:封装判断是否存在这个枚举
    MYSQL插入emoji报错解决方法Incorrect string value
    文件大小转换带上单位工具类(文件byte自动转KBMBGB)
    mysql 统计七天数据并分组
    mybatis plus 和 druid 版本导致LocalDateTime 不兼容问题
    Layui弹框中select下拉列表赋值回显
    查看环境版本
    Linux 常用命令
    安装jdk14的坑
    modbus_tk解析
  • 原文地址:https://www.cnblogs.com/panda-blog/p/12305782.html
Copyright © 2011-2022 走看看