zoukankan      html  css  js  c++  java
  • 使用TenforFlow 搭建BP神经网络拟合二次函数

    使用简单BP神经网络拟合二次函数

    当拥有两层神经元时候,拟合程度明显比一层好
    并出现如下警告:

    C:Program FilesPython36libsite-packagesmatplotlibackend_bases.py:2453: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented warnings.warn(str, mplDeprecation)
    

    偶尔画出直线,不知为何
    当Learning Rate越高,或者层数越多
    画出直线或者罢工的几率就越大,愿评论区能出现解答
    罢工的原因: NaN
    直线:未知?,NN学错东西了

    Code(nn with 2 hidden layer)

    import tensorflow as tf
    import matplotlib.pyplot as plt
    import numpy as np
    
    # refenrence: http://blog.csdn.net/jacke121/article/details/74938031
    
    ''' numpy.linspace test  
    import numpy as np
    # list_random1 = np.linspace(-1, 1, 300)
    # list_random2 = np.linspace(-1, 1, 300)[:, np.newaxis]
    # print(list_random1)
    # print(list_random1)
    # print(np.shape(list_random1), np.shape(list_random2))
    #     (300)                   (300, 1)
    
    x_data = np.linspace(-1,1,10)[:, np.newaxis]
    noise = np.random.normal(0, 0.05, x_data.shape)
    y_data = np.square(x_data) - 0.5 + noise
    x_data2 = tf.random_uniform([10, 1], -1, 1)
    y_data2 = tf.square(x_data2) - tf.random_normal(x_data2.shape, 2)
    print(x_data, '
    
    ', x_data.shape, '
    
    ')     # (10, 1) 
    print(y_data, '
    
    ', y_data.shape, '
    
    ')  
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(x_data2))
        print(x_data2)                              # Tensor("random_uniform:0", shape=(10, 1), dtype=float32)
        print(sess.run(y_data2))
        print(y_data2)
    '''
    
    
    
    # add layer function
    def add_layer(inputs, input_size, output_size, activation_function = None):
        # add one more layer and return the output of this layer
        Weights = tf.Variable(tf.random_normal([input_size, output_size]))
        biases = tf.Variable(tf.zeros([1, output_size]) + 0.1)
        outputs = tf.matmul(inputs, Weights) + biases
        if activation_function is not None:
            outputs = activation_function(outputs)
        return outputs
    
    
    
    # creat data
    # x_data: random[-1, 1) shape=[100,1] uniform distribute
    # y_data: normal distribute(2) of random numbers, shape = [100, 1]
    x_data = np.linspace(-1,1,100)[:, np.newaxis]
    noise = np.random.normal(0, 0.05, x_data.shape)
    y_data = np.square(x_data) - 0.5 + noise
    
    # define placehold for inputs of nn
    x = tf.placeholder(tf.float32, [None, 1])
    y = tf.placeholder(tf.float32, [None, 1])
    
    # hidden layer
    layer1 = add_layer(x, 1, 10, activation_function=tf.nn.relu)
    # hidden layer2
    layer2 = add_layer(layer1, 10, 10, activation_function=tf.nn.relu)
    # output layer
    output_layer = add_layer(layer2, 10, 1, activation_function=None)
    
    # define loss for nn
    loss = tf.reduce_mean(tf.reduce_sum(tf.square(y - output_layer),
                         reduction_indices=[1]))
    train = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    
    # visualize the result
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.scatter(x_data, y_data)
    plt.ion()
    plt.show()
    
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for i in range(1, 500):
            sess.run(train, feed_dict={x: x_data, y: y_data})
            # visualize the result
            if i%20 == 0:
                try:
                    ax.lines.remove(lines[0])
                except Exception:
                    pass
                output = sess.run(output_layer, feed_dict={x: x_data})
                lines = ax.plot(x_data, output, 'r-', lw=5)
                plt.pause(0.1)
    
    plt.pause(100)
  • 相关阅读:
    【Azure Redis Cache】对StackExchange.Redis IOCP错误消息的解读
    【Azure Developer】使用REST API获取Activity Logs、传入Data Lake的数据格式问题
    【Azure 存储服务】Blob中数据通过Stream Analytics导出到SQL/Cosmos DB
    【Azure Redis 缓存】Linux VM使用6380端口(SSL方式)连接Azure Redis (redis-cli & stunnel)
    【Azure 应用服务】在Azure App Service for Windows 中部署Java/NodeJS/Python项目时,web.config的配置模板内容
    【Azure Service Bus】 Service Bus如何确保消息发送成功,发送端是否有Ack机制 
    领域驱动实践总结(基本理论总结与分析+架构分析与代码设计+具体应用设计分析V)
    Java三元表达式中的陷阱
    Java有陷阱——慎用入参做返回值
    Eclipse中安装反编译工具Fernflower(Enhanced Class Decompiler)
  • 原文地址:https://www.cnblogs.com/liyuquan/p/7441855.html
Copyright © 2011-2022 走看看