zoukankan      html  css  js  c++  java
  • Tensorflow笔记——神经网络图像识别(一)前反向传播,神经网络八股

     

    第一讲:人工智能概述

       

     

    第三讲:Tensorflow框架

     

     

     

     

    前向传播:

    反向传播:

    总的代码:

    #coding:utf-8
    #1.导入模块,生成模拟数据集
    import tensorflow as tf
    import numpy as np #np为科学计算模块
    BATCH_SIZE = 8#表示一次喂入NN多少组数据,不能过大,会噎着
    seed = 23455
    
    #基于seed产生随机数
    rng = np.random.RandomState(seed)
    #随机数返回32*2列的矩阵,每行2个表示属性(体积和质量),作为输入数据集
    X = rng.rand(32,2)
    #从x这个矩阵中,取出每一行,判断如果和<1,y=1,否则,y=0
    #作为输入数据集的标签(正确答案)
    Y=[[int(x0+x1<1)] for (x0,x1) in X]
    print("X:
    ",X)
    print("Y:
    ",Y)
    
    #2.定义神经网络的输入,参数和输出,定义前向传播过程
    x = tf.placeholder(tf.float32,(None,2))
    y_ = tf.placeholder(tf.float32,(None,1))
    
    w1 = tf.Variable(tf.random_normal([2,3], stddev=1, seed=1))
    w2 = tf.Variable(tf.random_normal([3,1], stddev=1, seed=1))
    
    a = tf.matmul(x, w1)
    y = tf.matmul(a, w2)
    
    #3.定义损失函数及反向传播方法
    loss = tf.reduce_mean(tf.square(y-y_))
    train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)#梯度下降
    #train_step = tf.train.MomentumOptimizer(0.001,0.9).minimize(loss)
    #train_step = tf.train.AdamOptimizer(0.001,0.9).minimize(loss)
    
    #4.生成会话,训练STEPS轮
    with tf.Session() as sess:
        init_op=tf.global_variables_initializer()
        sess.run(init_op)
        #输出目前还未训练的参数取值
        print("w1:
    ", sess.run(w1))
        print("w2:
    ", sess.run(w2))
        print("
    ")
    
    #训练模型
        STEPS=3000
        #训练3000轮,每次从训练集中挑选strart到end的数据,喂入数据
        for i in range(STEPS):
            start = (i*BATCH_SIZE)%32
            end = start + BATCH_SIZE
            sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
            if i%500 == 0:#每500次打印一轮
                total_loss = sess.run(loss,feed_dict={x:X,y_:Y})
                print("After %d training steps,loss on all data is %g" % (i,total_loss))
    
    
        #输出训练后的参数取值
        print("
    ")
        print("w1:
    ", sess.run(w1))
        print("w2:
    ", sess.run(w2))

     输出的结果:

    X:
     [[ 0.83494319  0.11482951]
     [ 0.66899751  0.46594987]
     [ 0.60181666  0.58838408]
     [ 0.31836656  0.20502072]
     [ 0.87043944  0.02679395]
     [ 0.41539811  0.43938369]
     [ 0.68635684  0.24833404]
     [ 0.97315228  0.68541849]
     [ 0.03081617  0.89479913]
     [ 0.24665715  0.28584862]
     [ 0.31375667  0.47718349]
     [ 0.56689254  0.77079148]
     [ 0.7321604   0.35828963]
     [ 0.15724842  0.94294584]
     [ 0.34933722  0.84634483]
     [ 0.50304053  0.81299619]
     [ 0.23869886  0.9895604 ]
     [ 0.4636501   0.32531094]
     [ 0.36510487  0.97365522]
     [ 0.73350238  0.83833013]
     [ 0.61810158  0.12580353]
     [ 0.59274817  0.18779828]
     [ 0.87150299  0.34679501]
     [ 0.25883219  0.50002932]
     [ 0.75690948  0.83429824]
     [ 0.29316649  0.05646578]
     [ 0.10409134  0.88235166]
     [ 0.06727785  0.57784761]
     [ 0.38492705  0.48384792]
     [ 0.69234428  0.19687348]
     [ 0.42783492  0.73416985]
     [ 0.09696069  0.04883936]]
    Y:
     [[1], [0], [0], [1], [1], [1], [1], [0], [1], [1], [1], [0], [0], [0], [0], [0], [0], [1], [0], [0], [1], [1], [0], [1], [0], [1], [1], [1], [1], [1], [0], [1]]
    w1:
     [[-0.81131822  1.48459876  0.06532937]
     [-2.4427042   0.0992484   0.59122431]]
    w2:
     [[-0.81131822]
     [ 1.48459876]
     [ 0.06532937]]
    
    
    After 0 training steps,loss on all data is 5.13118
    After 500 training steps,loss on all data is 0.429111
    After 1000 training steps,loss on all data is 0.409789
    After 1500 training steps,loss on all data is 0.399923
    After 2000 training steps,loss on all data is 0.394146
    After 2500 training steps,loss on all data is 0.390597
    
    
    w1:
     [[-0.70006633  0.9136318   0.08953571]
     [-2.3402493  -0.14641267  0.58823055]]
    w2:
     [[-0.06024267]
     [ 0.91956186]
     [-0.0682071 ]]

     

     

     

     

  • 相关阅读:
    GSM Arena 魅族mx四核评测个人翻译
    Oracle Exists用法|转|
    NC公有协同的实现原理|同13的QQ||更新总部往来协同|
    NC客商bd_custbank不可修改账号、名称但可修改默认银行并更新分子公司trigger
    试玩了plsql中test窗口declare声明变量|lpad函数||plsql sql command test window区别|
    使用windows live writer测试
    用友写insert on bd_custbank 触发器和自动更新单位名称2in1
    oracle触发器select into和cursor用法的区别
    |转|oracle中prior的用法,connect by prior,树形目录
    客商增加自动增加银行账户|搞定!||更新使用游标course写法|
  • 原文地址:https://www.cnblogs.com/caiyishuai/p/9467580.html
Copyright © 2011-2022 走看看