zoukankan      html  css  js  c++  java
  • TensorFlow从入门到理解(六):可视化梯度下降

    运行代码:

    import tensorflow as tf
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    
    LR = 0.1                    # learning rate for the gradient-descent optimizer
    REAL_PARAMS = [1.2, 2.5]    # (a, b) used to generate the target curve
    # Candidate starting points for (a, b); the trailing index picks the one in use.
    INIT_PARAMS = [
        [5, 4],
        [5, 1],
        [2, 4.5],
    ][2]

    x = np.linspace(-1, 1, 200, dtype=np.float32)   # x data


    def y_fun(a, b):
        """Evaluate sin(b * cos(a * x)) with NumPy over the module-level x."""
        return np.sin(b * np.cos(a * x))


    def tf_y_fun(a, b):
        """Build the same sin(b * cos(a * x)) expression out of TensorFlow ops."""
        return tf.sin(b * tf.cos(a * x))


    # Target = noiseless curve at REAL_PARAMS plus Gaussian noise scaled down by 10.
    noise = np.random.randn(200) / 10
    y = y_fun(*REAL_PARAMS) + noise         # target
    
    # TensorFlow graph (TF 1.x graph-mode API: tf.Session, tf.train optimizers).
    # a and b are the two trainable scalars, initialised from INIT_PARAMS.
    a, b = [tf.Variable(initial_value=p, dtype=tf.float32) for p in INIT_PARAMS]
    pred = tf_y_fun(a, b)
    mse = tf.reduce_mean(tf.square(y - pred))
    train_op = tf.train.GradientDescentOptimizer(LR).minimize(mse)

    # Trajectory of (a, b, cost). Each entry is sampled BEFORE the update of
    # that iteration, so index 0 holds the starting point.
    a_list, b_list, cost_list = [], [], []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(400):
            a_, b_, mse_ = sess.run([a, b, mse])
            a_list.append(a_)
            b_list.append(b_)
            cost_list.append(mse_)
            sess.run(train_op)

        # Read the final state in its own run. Fetching `pred` in the same
        # run as `train_op` gives no ordering guarantee between the read and
        # the update, so the fitted curve could come from either side of the
        # last step; fetching everything together here keeps the reported
        # a_, b_, mse_ and `result` mutually consistent.
        a_, b_, mse_, result = sess.run([a, b, mse, pred])
    
    
    # visualization codes:
    print('a=', a_, 'b=', b_)

    # Figure 1: noisy samples (blue dots) and the fitted curve (red line).
    plt.figure(1)
    plt.scatter(x, y, c='b')
    plt.plot(x, result, 'r-', lw=2)

    # Figure 2: cost surface over the (a, b) plane with the descent path on top.
    fig = plt.figure(2)
    # Create the 3D axes through the figure. Instantiating Axes3D(fig) directly
    # no longer attaches the axes to the figure on current Matplotlib, which
    # leaves the window empty. The mplot3d import at the top of the file keeps
    # the '3d' projection registered on older releases.
    ax = fig.add_subplot(111, projection='3d')

    # Parameter space: a 30 x 30 grid covering [-2, 7] on both axes.
    a3D, b3D = np.meshgrid(np.linspace(-2, 7, 30), np.linspace(-2, 7, 30))
    # MSE against the noisy target at every grid point. The loop variables are
    # deliberately not called a_/b_ so they cannot shadow (or, on Python 2,
    # overwrite) the trained values printed above.
    cost3D = np.array([
        np.mean(np.square(y_fun(a_grid, b_grid) - y))
        for a_grid, b_grid in zip(a3D.flatten(), b3D.flatten())
    ]).reshape(a3D.shape)

    ax.plot_surface(a3D, b3D, cost3D, rstride=1, cstride=1,
                    cmap=plt.get_cmap('rainbow'), alpha=0.5)
    # Large red marker at the first recorded (a, b, cost) point.
    ax.scatter(a_list[0], b_list[0], zs=cost_list[0], s=300, c='r')
    ax.set_xlabel('a')
    ax.set_ylabel('b')
    # Red line through every recorded point: the gradient-descent trajectory.
    ax.plot(a_list, b_list, zs=cost_list, zdir='z', c='r', lw=3)
    plt.show()

    运行结果:程序会打印训练得到的 a、b,并显示两幅图——图 1 为带噪声的数据散点与拟合曲线,图 2 为参数空间 (a, b) 上的代价曲面及梯度下降轨迹。

  • 相关阅读:
    memory consistency
    网页基础
    ECC
    RSA
    argparse模块
    009-MySQL循环while、repeat、loop使用
    001-mac搭建Python开发环境、Anaconda、zsh兼容
    013-在 Shell 脚本中调用另一个 Shell 脚本的三种方式
    012-Shell 提示确认(Y / N,YES / NO)
    014-docker-终端获取 docker 容器(container)的 ip 地址
  • 原文地址:https://www.cnblogs.com/darklights/p/9939339.html
Copyright © 2011-2022 走看看