zoukankan      html  css  js  c++  java
  • tensorflow入门03——搭建两层的简单神经网络

    搭建一个简单的神经网络对mnist数据集中的手写数字数据集进行训练和测试。

    输入的每张数据包含784个像素点,第一层为784行256列的矩阵,第二层是256行128列的矩阵,输出层则将结果转换为10个输出值,代表手写数字的10种分类结果,每层有一个权重值weight和偏置bias

    代码实现

    #搭建两层的神经网络
    import numpy as np
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
    import matplotlib.pyplot as plt
    import input_data
    
    #加载数据集
    minst=input_data.read_data_sets('data/data/', one_hot=True)
    
    #设置参数
    n_hidden_1=256 #第一层输出
    n_hidden_2=128 #第二层输出
    n_input=784 #输入像素点
    n_classes=10 #分类结果
    
    #输入输出
    x=tf.placeholder("float",[None,n_input])
    y=tf.placeholder("float",[None,n_classes])
    
    #神经网络参数
    stddev=0.1
    weights={
        'w1':tf.Variable(tf.random_normal([n_input,n_hidden_1],stddev=stddev)),  #高斯初始化
        'w2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2],stddev=stddev)),
        'out':tf.Variable(tf.random_normal([n_hidden_2,n_classes],stddev=stddev))
    }
    biases={
        'b1':tf.Variable(tf.random_normal([n_hidden_1])),
        'b2':tf.Variable(tf.random_normal([n_hidden_2])),
        'out':tf.Variable(tf.random_normal([n_classes]))
    
    }
    print('NetWork Ready')
    
    def multilayer_perceptron(_X,_weights,_biases):
        layer_1=tf.nn.sigmoid(tf.add(tf.matmul(_X,_weights['w1']),biases['b1']))   #sigmoid激活函数
        layer_2=tf.nn.sigmoid(tf.add(tf.matmul(layer_1,_weights['w2']),_biases['b2']))
        return (tf.matmul(layer_2,_weights['out'])+_biases['out'])
    
    #预测
    pred=multilayer_perceptron(x,weights,biases)
    
    #损失和优化器
    learning_rate=0.01
    cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(y,pred))
    optm=tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)
    corr=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
    accr=tf.reduce_mean(tf.cast(corr,"float"))
    
    #初始化
    init=tf.global_variables_initializer()
    print('function ready')
    
    #训练参数
    training_epochs=20
    batch_size=100
    display_step=4
    #开始训练
    sess=tf.Session()
    sess.run(init)
    feeds={}
    for epoch in range(training_epochs):
        avg_cost=0.
        total_batch=int(minst.train.num_examples/batch_size)
        for i in range(total_batch):
            batch_xs,batch_ys=minst.train.next_batch(batch_size)
            feeds={x:batch_xs,y:batch_ys}
            sess.run(optm,feed_dict=feeds)
            avg_cost+=sess.run(cost,feed_dict=feeds)
        avg_cost=avg_cost/total_batch
        if(epoch+1)%display_step==0:
            print('Epoch:%03d/%03d cost:%.9f'%(epoch+1,training_epochs,avg_cost))
            train_acc=sess.run(accr,feed_dict=feeds)
            print('TRAIN ACCURACY:%.3f'%(train_acc))
            feeds={x:minst.test.images,y:minst.test.labels}
            test_acc=sess.run(accr,feed_dict=feeds)
            print('TEST ACCURACY:%.3f'%(test_acc))
    print('训练完成')

     训练结果

  • 相关阅读:
    设计模式学习总结系列应用实例
    【研究课题】高校特殊学生的发现及培养机制研究
    Linux下Oracle11G RAC报错:在安装oracle软件时报file not found一例
    python pro practice
    openstack python sdk list tenants get token get servers
    openstack api
    python
    git for windows
    openstack api users list get token get servers
    linux 流量监控
  • 原文地址:https://www.cnblogs.com/XiaoGao128/p/14281464.html
Copyright © 2011-2022 走看看