zoukankan      html  css  js  c++  java
  • 用tensorflow构建神经网络学习简单函数

    目标是学习(y=2x+3)
    建立一个5层的神经网络,用平方误差作为损失函数。
    代码如下:

    import tensorflow as tf
    import numpy as np
    import time
    
    x_size=200000
    dim=2
    x_data=np.random.random([x_size,dim]).astype('float32')
    y_data=2*x_data+3
    x_test=np.random.random([10,dim]).astype('float32')
    y_test=2*x_test+3
    
    train_x=tf.placeholder(tf.float32,shape=[None,dim])
    train_y=tf.placeholder(tf.float32,shape=[None,dim]) 
    
    weight1=tf.Variable(tf.truncated_normal([dim,40],stddev=0.1))
    b1=tf.Variable(tf.zeros([40])+0.1)
    h1=tf.nn.relu(tf.matmul(train_x,weight1)+b1)
    
    weight2=tf.Variable(tf.truncated_normal([40,40],stddev=0.1))
    b2=tf.Variable(tf.zeros([40])+0.1)
    h2=tf.nn.relu(tf.matmul(h1,weight2)+b2)
    
    weight3=tf.Variable(tf.truncated_normal([40,40],stddev=0.1))
    b3=tf.Variable(tf.zeros([40])+0.1)
    h3=tf.nn.relu(tf.matmul(h2,weight3)+b3)
    
    weight4=tf.Variable(tf.truncated_normal([40,40],stddev=0.1))
    b4=tf.Variable(tf.zeros([40])+0.1)
    h4=tf.nn.relu(tf.matmul(h3,weight4)+b4)
    
    weight5=tf.Variable(tf.truncated_normal([40,dim],stddev=0.1))
    b5=tf.Variable(tf.zeros([dim])+0.1)
    y_output=tf.nn.relu(tf.matmul(h4,weight5)+b5)
    
    loss=tf.reduce_mean(tf.square(train_y-y_output))
    optimizer=tf.train.GradientDescentOptimizer(0.5)
    train_step=optimizer.minimize(loss)
    
    t1=time.time()
    sess=tf.Session()
    sess.run(tf.global_variables_initializer())
    for i in range(2000):
        feed_train={
            train_x:x_data,
            train_y:y_data
        }
        if i%100==0:
            print('loss:',sess.run(loss,feed_dict=feed_train),end=',   ')
        sess.run(train_step,feed_dict=feed_train)
    print()
    t2=time.time()
    print('Total Time:',t2-t1)
    print('test') 
    for i in range(10):
        feed_test={train_x:x_test[i:i+1],train_y:y_test[i:i+1]}    
        print('y:       ',sess.run(train_y,feed_dict=feed_test))
        print('y_output:',sess.run(y_output,feed_dict=feed_test))
        print('loss:',sess.run(loss,feed_dict=feed_test))
    sess.close()
    

    结果:

    loss: 15.4106,   loss: 0.232037,   loss: 0.211914,   loss: 0.198133,   loss: 0.0544874,   loss: 0.0280089,   loss: 0.0211618,   loss: 0.0173591,   loss: 0.0109964,   loss: 0.00902615,   loss: 0.00815686,   loss: 0.00941989,   loss: 0.00619169,   loss: 0.00529554,   loss: 0.00506653,   loss: 0.00660528,   loss: 0.00382864,   loss: 0.00412649,   loss: 0.00610038,   loss: 0.00354737,   
    Total Time: 88.89598035812378
    test
    y:        [[ 4.46494102  4.53034449]]
    y_output: [[ 4.48269606  4.44468594]]
    loss: 0.00382631
    y:        [[ 3.21122026  4.36406898]]
    y_output: [[ 3.22117805  4.2706871 ]]
    loss: 0.00440967
    y:        [[ 3.58840036  4.41665506]]
    y_output: [[ 3.59200501  4.3375597 ]]
    loss: 0.00313453
    y:        [[ 3.49797821  4.21883869]]
    y_output: [[ 3.51356149  4.14429617]]
    loss: 0.00289971
    y:        [[ 3.75655651  4.35610151]]
    y_output: [[ 3.76163697  4.26597834]]
    loss: 0.004074
    y:        [[ 4.52173853  4.32090807]]
    y_output: [[ 4.53192806  4.2343545 ]]
    loss: 0.00379767
    y:        [[ 4.19067335  4.8417387 ]]
    y_output: [[ 4.20001888  4.73385048]]
    loss: 0.0058636
    y:        [[ 4.58287668  3.89965653]]
    y_output: [[ 4.59979439  3.84099913]]
    loss: 0.00186345
    y:        [[ 4.25389147  3.75640154]]
    y_output: [[ 4.23791742  3.69044876]]
    loss: 0.00230247
    y:        [[ 3.40870714  4.49888897]]
    y_output: [[ 3.41926885  4.42829704]]
    loss: 0.00254738
    

    可以看出在训练集上loss不断减小,最后下降到0.00354737,而在测试集上loss也在0.003左右。
    由于参数是随机设置的,有时候可能陷入局部最优中,多运行几次可以减少陷入局部最优的概率。

    将优化算法换成:

    optimizer=tf.train.AdamOptimizer()
    

    后的结果:

    loss: 15.6427,   loss: 0.197051,   loss: 0.174776,   loss: 0.164641,   loss: 0.15766,   loss: 0.131154,   loss: 0.0029341,   loss: 0.000404288,   loss: 0.000178629,   loss: 9.63827e-05,   loss: 5.74653e-05,   loss: 3.65505e-05,   loss: 2.44332e-05,   loss: 1.69916e-05,   loss: 1.22397e-05,   loss: 9.06447e-06,   loss: 6.86902e-06,   loss: 5.31113e-06,   loss: 4.16228e-06,   loss: 3.30907e-06,   
    Total Time: 89.90041589736938
    test
    y:        [[ 4.46494102  4.53034449]]
    y_output: [[ 4.46485758  4.53046322]]
    loss: 1.05304e-08
    y:        [[ 3.21122026  4.36406898]]
    y_output: [[ 3.21072125  4.36450434]]
    loss: 2.19271e-07
    y:        [[ 3.58840036  4.41665506]]
    y_output: [[ 3.58802533  4.41699553]]
    loss: 1.28282e-07
    y:        [[ 3.49797821  4.21883869]]
    y_output: [[ 3.49763799  4.2191186 ]]
    loss: 9.70489e-08
    y:        [[ 3.75655651  4.35610151]]
    y_output: [[ 3.75626636  4.35636234]]
    loss: 7.61112e-08
    y:        [[ 4.52173853  4.32090807]]
    y_output: [[ 4.52174997  4.32091379]]
    loss: 8.18545e-11
    y:        [[ 4.19067335  4.8417387 ]]
    y_output: [[ 4.19037819  4.84208441]]
    loss: 1.03317e-07
    y:        [[ 4.58287668  3.89965653]]
    y_output: [[ 4.58305788  3.89945245]]
    loss: 3.7242e-08
    y:        [[ 4.25389147  3.75640154]]
    y_output: [[ 4.25399828  3.75623488]]
    loss: 1.95912e-08
    y:        [[ 3.40870714  4.49888897]]
    y_output: [[ 3.40823555  4.49932337]]
    loss: 2.05551e-07
    
    使用RMSPropOptimizer,最小loss:0.33
    使用FtrlOptimizer,最小loss:0.17
    使用MomentumOptimizer(learning_rate=0.1,momentum=0.6),loss:4.47119e-06, 但是不是很稳定。
    
  • 相关阅读:
    5-2 bash 脚本编程之一 变量、变量类型等
    4-4 grep及正则表达式
    4-3 管理及IO重定向
    4-2 权限及权限管理
    CentOS7 发布 ASP.NET MVC 4 --- mono 4.6.0 + jexus 5.8.1
    CentOS7 安装 nginx
    Hibernate学习笔记--------4.查询
    Hibernate学习笔记--------3.缓存
    Hibernate学习笔记--------2.一多|多多的CRUD
    Hibernate学习笔记--------1.单表操作
  • 原文地址:https://www.cnblogs.com/sandy-t/p/6916766.html
Copyright © 2011-2022 走看看