zoukankan      html  css  js  c++  java
  • 1.5神经网络可视化显示(matplotlib)

    神经网络训练+可视化显示

    #添加隐层的神经网络结构+可视化显示
    import tensorflow as tf
    
    def add_layer(inputs,in_size,out_size,activation_function=None):
        #定义权重--随机生成inside和outsize的矩阵
        Weights=tf.Variable(tf.random_normal([in_size,out_size]))
        #不是矩阵,而是类似列表
        biaes=tf.Variable(tf.zeros([1,out_size])+0.1)
        Wx_plus_b=tf.matmul(inputs,Weights)+biaes
        if activation_function is  None:
            outputs=Wx_plus_b
        else:
            outputs=activation_function(Wx_plus_b)
        return outputs
    
    import numpy as np
    x_data=np.linspace(-1,1,300)[:,np.newaxis] #300行数据
    noise=np.random.normal(0,0.05,x_data.shape)
    y_data=np.square(x_data)-0.5+noise
    #None指定sample个数,这里不限定--输出属性为1
    xs=tf.placeholder(tf.float32,[None,1])  #这里需要指定tf.float32,
    ys=tf.placeholder(tf.float32,[None,1])
    
    #建造第一层layer
    #输入层(1)
    l1=add_layer(xs,1,10,activation_function=tf.nn.relu)
    #隐层(10)
    prediction=add_layer(l1,10,1,activation_function=None)
    #输出层(1)
    #预测prediction
    loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),
                       reduction_indices=[1])) #平方误差
    train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    
    init=tf.initialize_all_variables()
    sess=tf.Session()
    #直到执行run才执行上述操作
    sess.run(init)
    
    
    import matplotlib.pyplot as plt
    fig=plt.figure()
    ax=fig.add_subplot(111)
    ax.scatter(x_data,y_data)
    plt.ion() #图像会连续显示
    #plt.show()不会终止整个函数。在2.x时候使用plt.show(block=False)
    plt.show()
    
    
    for i in range(1000):
        #这里假定指定所有的x_data来指定运算结果
        sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
        if i%50:
            # print (sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
            try:
                #忽略第一次的错误
                ax.lines.remove(lines[0]) #在图片中去掉lines的第1条线段,不然线会混乱
            except Exception:
                prediction_value=sess.run(prediction,feed_dict={xs:x_data})
                lines=ax.plot(x_data,prediction_value,'r-',lw=5)
                # ax.lines.remove(lines[0]) 移动上上面,先移除第一条线
                plt.pause(0.2) #每次暂停0.2s

    显示:

  • 相关阅读:
    RabbitMQ 内存控制 硬盘控制
    Flannel和Docker网络不通定位问题
    kafka集群扩容后的topic分区迁移
    CLOSE_WAIT状态的原因与解决方法
    搭建Harbor企业级docker仓库
    Redis哨兵模式主从持久化问题解决
    mysql杂谈(爬坑,解惑,总结....)
    Linux的信号量(semaphore)与互斥(mutex)
    SIP协议的传输层原理&报文解析(解读rfc3581)(待排版) && opensips
    SIP协议的传输层原理&报文解析(解读RFC3261)(待排版)&&启动
  • 原文地址:https://www.cnblogs.com/jackchen-Net/p/8082562.html
Copyright © 2011-2022 走看看