zoukankan      html  css  js  c++  java
  • 使用tensorflow预测函数的参数值(a simple task)

    已知x1,x2,x3,y
    根据y = aax1 + bbx2 + abx3
    预测参数a和b

    import tensorflow as tf
    import numpy as np
    
    x1_data = np.random.rand(100).astype(np.float32)
    x2_data = np.random.rand(100).astype(np.float32)
    x3_data = np.random.rand(100).astype(np.float32)
    y_data = x1_data*0.5*0.5 + x2_data*0.8*0.8 + x3_data*0.5*0.8
    
    a = tf.Variable(tf.random_uniform([1]))
    b = tf.Variable(tf.random_uniform([1]))
    
    y = a*a*x1_data + b*b*x2_data + a*b*x3_data
    
    loss = tf.reduce_mean(tf.square(y - y_data))
    optimizer = tf.train.AdamOptimizer(0.1)
    train = optimizer.minimize(loss)
    
    init = tf.initialize_all_variables()
    
    with tf.Session() as sess:
        sess.run(init)
        for step in range(202):
            sess.run(train)
            if step % 20 ==0:
                print(step,sess.run(a),sess.run(b),sess.run(loss))
    

    Output:
    0 [0.16722116] [0.4231345] 0.287351
    20 [0.46662363] [0.81065917] 0.00041220474
    40 [0.51124895] [0.83035856] 0.0021522618
    60 [0.50623244] [0.79908705] 2.3531691e-05
    80 [0.49737427] [0.79750854] 2.718182e-05
    100 [0.50043434] [0.80122584] 3.336514e-06
    120 [0.5001132] [0.799731] 5.6176066e-08
    140 [0.4998217] [0.7999378] 5.3850208e-08
    160 [0.50005966] [0.8000353] 8.847866e-09
    180 [0.50001425] [0.8000164] 1.0047848e-09
    200 [0.49999878] [0.79999965] 2.2931767e-12

  • 相关阅读:
    各种有趣言论收集
    人类未来进化方向恶考
    mysql 列所有表行数
    恩,有那么一个人
    00后厉害哇
    。。。。
    放弃微博,继续回来写月经
    嘿,大家还好吗
    git
    require js
  • 原文地址:https://www.cnblogs.com/bernieloveslife/p/10216194.html
Copyright © 2011-2022 走看看