zoukankan      html  css  js  c++  java
  • tensorflow常用函数解释

    从二维数组中选一个矩形

    import tensorflow as tf
    # Two rows of eight values; the slice below keeps a rectangular sub-block.
    matrix = [[1, 2, 3, 4, 5, 6, 7, 8], [11, 12, 13, 14, 15, 16, 17, 18]]
    # Rows [0, 2) and columns [0, 4): the first four entries of both rows.
    top_left = tf.strided_slice(matrix, begin=[0, 0], end=[2, 4])
    with tf.Session() as sess:
        print(sess.run(top_left))
    

    numpy array转tensor

    import tensorflow as tf
    import numpy as np
    # The same three values held in a plain Python list and in a NumPy array.
    py_values = [1, 2, 3]
    np_values = np.array([1, 2, 3])
    # tf.convert_to_tensor accepts either container.
    list_tensor = tf.convert_to_tensor(py_values)
    array_tensor = tf.convert_to_tensor(np_values)
    with tf.Session() as sess:
        # Show the Python types of the two inputs first ...
        for source in (py_values, np_values):
            print(type(source))
        # ... then the evaluated tensors built from them.
        for tensor in (list_tensor, array_tensor):
            print(sess.run(tensor))
    

    tf.train.Supervisor 用法

    http://www.cnblogs.com/zhouyang209117/p/7088051.html
    

    使用训练好的模型

    import tensorflow as tf
    import numpy as np
    import os
    # Directory and file name used for the checkpoint.
    # NOTE(review): the published literal was r"D:Sourcemodellinear" with no
    # separators, presumably stripped during publishing -- confirm the path.
    log_path = r"D:\Source\model\linear"
    log_name = "linear.ckpt"
    # Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
    x_data = np.random.rand(100).astype(np.float32)
    y_data = x_data * 0.1 + 0.3
    # Try to find values for W and b that compute y_data = W * x_data + b
    # (We know that W should be 0.1 and b 0.3, but TensorFlow will
    # figure that out for us.)
    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
    b = tf.Variable(tf.zeros([1]))
    y = W * x_data + b
    # Minimize the mean squared errors.
    loss = tf.reduce_mean(tf.square(y - y_data))
    optimizer = tf.train.GradientDescentOptimizer(0.5)
    train = optimizer.minimize(loss)
    # The saver writes/reads W and b; the init op gives them starting values.
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    # Make sure the checkpoint directory exists before restoring or saving.
    if not os.path.isdir(log_path):
        os.makedirs(log_path)
    # Launch the graph; the context manager closes the session afterwards.
    with tf.Session() as sess:
        sess.run(init)
        # Resume from the newest checkpoint if one has been written already;
        # latest_checkpoint returns None when the directory holds none.
        checkpoint = tf.train.latest_checkpoint(log_path)
        if checkpoint is not None:
            saver.restore(sess, checkpoint)
        for step in range(201):
            sess.run(train)
            if step % 20 == 0:
                print(step, sess.run(W), sess.run(b))
        saver.save(sess, os.path.join(log_path, log_name))
    

    数组添加一列

    import tensorflow as tf
    # A single column and the 3x4 matrix it is joined to.
    a = [[1], [2], [3]]
    b = [[1, 2, 3, 4], [1, 3, 6, 7], [4, 2, 1, 6]]
    # Concatenate along axis 1 (columns), giving a 3x5 result.  The context
    # manager releases the session's resources when the block ends.
    with tf.Session() as sess:
        print(sess.run(tf.concat([a, b], 1)))
    

    结果为:

    [[1 1 2 3 4]
     [2 1 3 6 7]
     [3 4 2 1 6]]
    

    一维数组合成二维数组,二维数组拆分成一维数组

    import tensorflow as tf
    # Three length-3 vectors.
    x = tf.constant([1, 2, 3])
    y = tf.constant([4, 5, 6])
    z = tf.constant([7, 8, 9])
    # Stack them along a new leading axis into one 3x3 matrix.
    p = tf.stack([x, y, z])
    # The context manager releases the session's resources when done.
    with tf.Session() as sess:
        print(sess.run(p))
        # unstack along axis 0 yields the three row vectors again.
        print(sess.run(tf.unstack(p, num=3, axis=0)))
    

    tf.gather从数组中选出几个元素

    import tensorflow as tf
    sess = tf.Session()
    # Source vector and the positions to pick from it (repeats are allowed).
    source = tf.constant([6, 3, 4, 1, 5, 9, 10])
    wanted = tf.constant([2, 0, 2, 5])
    # Evaluate the gather op, then print the selected elements.
    picked = sess.run(tf.gather(source, wanted))
    print(picked)
    sess.close()
    

    expand_dims

    在指定维的前面增加一个长度为1的维.

    # coding:utf8
    import tensorflow as tf
    import numpy as np
    data = tf.constant([[1, 2, 1], [3, 1, 1]])
    # The context manager releases the session's resources when done.
    with tf.Session() as sess:
        print(sess.run(tf.shape(data)))  # (2, 3)
        # Each call inserts a length-1 axis before the given position;
        # -1 appends it at the end.
        d_1 = tf.expand_dims(data, 0)  # (1, 2, 3)
        d_1 = tf.expand_dims(d_1, 2)  # (1, 2, 1, 3)
        d_1 = tf.expand_dims(d_1, -1)  # (1, 2, 1, 3, 1)
        d_1 = tf.expand_dims(d_1, -1)  # (1, 2, 1, 3, 1, 1)
        print(sess.run(tf.shape(d_1)))
        d_2 = d_1
        # With no axis given, squeeze drops every length-1 axis: (2, 3).
        print(sess.run(tf.shape(tf.squeeze(d_1))))
        # Dropping only axes 2 and 4 leaves (1, 2, 3, 1).
        print(sess.run(tf.shape(tf.squeeze(d_2, [2, 4]))))
    

    squeeze

    和expand_dims的作用相反,去掉指定维,指定维的长度必须为1.

    # Random (50, 1, 64) array whose middle axis has length 1.
    a = np.random.random((50, 1, 64))
    b = tf.convert_to_tensor(a)
    # `axis` is the current name of the deprecated `squeeze_dims` keyword;
    # removing axis 1 leaves a (50, 64) tensor.
    c = tf.squeeze(b, axis=[1])
    with tf.Session() as sess:
        print(c.eval())
    

    split

    把二维数组沿指定维拆分成多个子数组(每个子数组仍是二维的).

    # coding:utf8
    import tensorflow as tf
    # 5x30 matrix of random normal values; named `matrix` so that the
    # builtin `input` is not shadowed.
    matrix = tf.random_normal([5, 30])
    # The context manager releases the session's resources when done.
    with tf.Session() as sess:
        # Rows per piece when the 5 rows are divided into 5 pieces; integer
        # division keeps the printed count an integer on Python 3 as well.
        print(sess.run(tf.shape(matrix))[0] // 5)
        # TF >= 1.0 argument order is (value, num_or_size_splits, axis).
        # Each piece keeps both dimensions, i.e. has shape (1, 30).
        split0, split1, split2, split3, split4 = tf.split(matrix, 5, axis=0)
        print(sess.run(tf.shape(split0)))
    

    按第0维拆分,拆分成5个(1,30)的数组.

    列出所有训练的变量

    # Print the name and static shape of every trainable variable in the graph.
    for variable in tf.trainable_variables():
        summary = "name={},shape={}".format(variable.name, variable.get_shape())
        print(summary)
    
  • 相关阅读:
    xunjian.sh
    192.168.50.235配置
    自动备份并删除旧日志
    bg和fg命令
    linux之sed用法
    正则表示第二行,第二列
    linux下redis安装
    Hma梳理
    linux 系统监控、诊断工具之 lsof 用法简介
    java的基本数据类型有八种
  • 原文地址:https://www.cnblogs.com/zhouyang209117/p/7142601.html
Copyright © 2011-2022 走看看