zoukankan      html  css  js  c++  java
  • 1.5神经网络可视化显示(matplotlib)

    神经网络训练+可视化显示

    #添加隐层的神经网络结构+可视化显示
    import tensorflow as tf
    
    def add_layer(inputs,in_size,out_size,activation_function=None):
        #定义权重--随机生成inside和outsize的矩阵
        Weights=tf.Variable(tf.random_normal([in_size,out_size]))
        #不是矩阵,而是类似列表
        biaes=tf.Variable(tf.zeros([1,out_size])+0.1)
        Wx_plus_b=tf.matmul(inputs,Weights)+biaes
        if activation_function is  None:
            outputs=Wx_plus_b
        else:
            outputs=activation_function(Wx_plus_b)
        return outputs
    
    import numpy as np
    x_data=np.linspace(-1,1,300)[:,np.newaxis] #300行数据
    noise=np.random.normal(0,0.05,x_data.shape)
    y_data=np.square(x_data)-0.5+noise
    #None指定sample个数,这里不限定--输出属性为1
    xs=tf.placeholder(tf.float32,[None,1])  #这里需要指定tf.float32,
    ys=tf.placeholder(tf.float32,[None,1])
    
    #建造第一层layer
    #输入层(1)
    l1=add_layer(xs,1,10,activation_function=tf.nn.relu)
    #隐层(10)
    prediction=add_layer(l1,10,1,activation_function=None)
    #输出层(1)
    #预测prediction
    loss=tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction),
                       reduction_indices=[1])) #平方误差
    train_step=tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    
    init=tf.initialize_all_variables()
    sess=tf.Session()
    #直到执行run才执行上述操作
    sess.run(init)
    
    
    import matplotlib.pyplot as plt
    fig=plt.figure()
    ax=fig.add_subplot(111)
    ax.scatter(x_data,y_data)
    plt.ion() #图像会连续显示
    #plt.show()不会终止整个函数。在2.x时候使用plt.show(block=False)
    plt.show()
    
    
    for i in range(1000):
        #这里假定指定所有的x_data来指定运算结果
        sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
        if i%50:
            # print (sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
            try:
                #忽略第一次的错误
                ax.lines.remove(lines[0]) #在图片中去掉lines的第1条线段,不然线会混乱
            except Exception:
                prediction_value=sess.run(prediction,feed_dict={xs:x_data})
                lines=ax.plot(x_data,prediction_value,'r-',lw=5)
                # ax.lines.remove(lines[0]) 移动上上面,先移除第一条线
                plt.pause(0.2) #每次暂停0.2s

    显示:

  • 相关阅读:
    Redis 安装(Windows)
    etcd简介与应用场景
    Nginx+SignalR+Redis(二)windows
    Nginx+SignalR+Redis(一)windows
    Windows 版MongoDB 复制集Replica Set 配置
    走进异步世界async、await
    认识和使用Task
    进程、应用程序域、线程的相互关系
    ASP.NET Core实现类库项目读取配置文件
    用idea做springboot开发,设置thymeleaf时候,新手容易忽略误区
  • 原文地址:https://www.cnblogs.com/jackchen-Net/p/8082562.html
Copyright © 2011-2022 走看看