zoukankan      html  css  js  c++  java
  • TensorFlow实现梯度下降

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Mon Oct 15 17:38:39 2018
     4 
     5 @author: zhen
     6 """
     7 
     8 import tensorflow as tf
     9 import numpy as np
    10 from sklearn.datasets import fetch_california_housing
    11 from sklearn.preprocessing import StandardScaler
    12 
    13 n_epochs = 10000
    14 learning_rate = 0.01
    15 
    16 housing = fetch_california_housing(data_home="C:/Users/zhen/.spyder-py3/data", download_if_missing=True)
    17 m, n = housing.data.shape
    18 housing_data_plus_bias = np.c_[np.ones((m, 1)), housing.data]
    19 # 归一化
    20 scaler= StandardScaler().fit(housing_data_plus_bias)
    21 scaled_housing_data_plus_bias = scaler.transform(housing_data_plus_bias)
    22 # 创建常量
    23 x = tf.constant(scaled_housing_data_plus_bias, dtype=tf.float32, name='x')
    24 y = tf.constant(housing.target.reshape(-1, 1), dtype=tf.float32, name='y')
    25 # 创建随机数
    26 theta = tf.Variable(tf.random_uniform([n + 1, 1], -1.0, 1.0), name='theta')
    27 # 矩阵乘
    28 y_pred = tf.matmul(x, theta, name="predictions") 
    29 
    30 error = y_pred - y
    31 # 求平均值
    32 mse = tf.reduce_mean(tf.square(error), name="mse")
    33 """
    34 # 求梯度
    35 gradients = tf.gradients(mse, [theta])[0]
    36 # 赋值
    37 training_op = tf.assign(theta, theta - learning_rate * gradients)
    38 """
    39 # 梯度下降
    40 optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    41 training_op = optimizer.minimize(mse)
    42 
    43 init = tf.global_variables_initializer()
    44 
    45 with tf.Session() as sess:
    46     sess.run(init)
    47     
    48     for epoch in range(n_epochs):
    49         if epoch % 100 == 0:
    50             print("Epoch", epoch, "MSE = ", mse.eval())
    51         sess.run(training_op)
    52         
    53     best_theta = theta.eval()
    54     print(best_theta)

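    结果(注:此输出中偏置列(全 1 列)在添加后又经过 StandardScaler 标准化,被变成了全 0 列,截距项因此失效 —— theta[0] 的梯度恒为零,始终停在随机初值 0.43350863,MSE 也停滞在约 4.80。应先对特征做标准化,再添加偏置列):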

    Epoch 0 MSE =  9.128207
    Epoch 100 MSE =  4.893214
    Epoch 200 MSE =  4.8329406
    Epoch 300 MSE =  4.824335
    Epoch 400 MSE =  4.8187895
    Epoch 500 MSE =  4.814753
    Epoch 600 MSE =  4.811796
    Epoch 700 MSE =  4.8096204
    Epoch 800 MSE =  4.808017
    Epoch 900 MSE =  4.806835
    Epoch 1000 MSE =  4.805955
    Epoch 1100 MSE =  4.805301
    Epoch 1200 MSE =  4.8048124
    Epoch 1300 MSE =  4.804449
    Epoch 1400 MSE =  4.804172
    Epoch 1500 MSE =  4.803962
    Epoch 1600 MSE =  4.8038034
    Epoch 1700 MSE =  4.803686
    Epoch 1800 MSE =  4.8035927
    Epoch 1900 MSE =  4.80352
    Epoch 2000 MSE =  4.8034678
    Epoch 2100 MSE =  4.803425
    Epoch 2200 MSE =  4.8033857
    Epoch 2300 MSE =  4.803362
    Epoch 2400 MSE =  4.803341
    Epoch 2500 MSE =  4.8033247
    Epoch 2600 MSE =  4.80331
    Epoch 2700 MSE =  4.8033013
    Epoch 2800 MSE =  4.8032923
    Epoch 2900 MSE =  4.8032856
    Epoch 3000 MSE =  4.8032804
    Epoch 3100 MSE =  4.803273
    Epoch 3200 MSE =  4.803271
    Epoch 3300 MSE =  4.8032694
    Epoch 3400 MSE =  4.803267
    Epoch 3500 MSE =  4.8032637
    Epoch 3600 MSE =  4.8032603
    Epoch 3700 MSE =  4.803259
    Epoch 3800 MSE =  4.803259
    Epoch 3900 MSE =  4.8032584
    Epoch 4000 MSE =  4.8032575
    Epoch 4100 MSE =  4.8032575
    Epoch 4200 MSE =  4.803256
    Epoch 4300 MSE =  4.803255
    Epoch 4400 MSE =  4.803256
    Epoch 4500 MSE =  4.803256
    Epoch 4600 MSE =  4.803253
    Epoch 4700 MSE =  4.8032565
    Epoch 4800 MSE =  4.803258
    Epoch 4900 MSE =  4.8032556
    Epoch 5000 MSE =  4.803256
    Epoch 5100 MSE =  4.8032537
    Epoch 5200 MSE =  4.8032565
    Epoch 5300 MSE =  4.803255
    Epoch 5400 MSE =  4.8032546
    Epoch 5500 MSE =  4.803254
    Epoch 5600 MSE =  4.8032537
    Epoch 5700 MSE =  4.8032517
    Epoch 5800 MSE =  4.8032527
    Epoch 5900 MSE =  4.8032537
    Epoch 6000 MSE =  4.803254
    Epoch 6100 MSE =  4.8032546
    Epoch 6200 MSE =  4.803255
    Epoch 6300 MSE =  4.8032546
    Epoch 6400 MSE =  4.803253
    Epoch 6500 MSE =  4.803253
    Epoch 6600 MSE =  4.803253
    Epoch 6700 MSE =  4.8032517
    Epoch 6800 MSE =  4.803252
    Epoch 6900 MSE =  4.8032517
    Epoch 7000 MSE =  4.803252
    Epoch 7100 MSE =  4.8032537
    Epoch 7200 MSE =  4.8032537
    Epoch 7300 MSE =  4.803253
    Epoch 7400 MSE =  4.803253
    Epoch 7500 MSE =  4.803253
    Epoch 7600 MSE =  4.803254
    Epoch 7700 MSE =  4.8032546
    Epoch 7800 MSE =  4.8032556
    Epoch 7900 MSE =  4.803256
    Epoch 8000 MSE =  4.8032565
    Epoch 8100 MSE =  4.8032565
    Epoch 8200 MSE =  4.8032565
    Epoch 8300 MSE =  4.8032556
    Epoch 8400 MSE =  4.8032565
    Epoch 8500 MSE =  4.8032575
    Epoch 8600 MSE =  4.8032565
    Epoch 8700 MSE =  4.803256
    Epoch 8800 MSE =  4.803256
    Epoch 8900 MSE =  4.8032556
    Epoch 9000 MSE =  4.803255
    Epoch 9100 MSE =  4.8032546
    Epoch 9200 MSE =  4.803254
    Epoch 9300 MSE =  4.8032546
    Epoch 9400 MSE =  4.8032546
    Epoch 9500 MSE =  4.803255
    Epoch 9600 MSE =  4.803255
    Epoch 9700 MSE =  4.803255
    Epoch 9800 MSE =  4.803255
    Epoch 9900 MSE =  4.803255
    [[ 0.43350863]
     [ 0.8296331 ]
     [ 0.11875448]
     [-0.26555073]
     [ 0.3057157 ]
     [-0.00450223]
     [-0.03932685]
     [-0.8998542 ]
     [-0.87051094]]
    View Code

    结果样例:

  • 相关阅读:
    使用多线程生产者消费者模式实现抓斗图
    selenium+chrome抓取淘宝搜索抓娃娃关键页面
    mysql必知必会
    mongoDB高级查询$type4array使用解析
    并发服务器几种实现方法总结
    python的面向对象和面向过程
    lazarus,synedit输入小键盘特殊符号的补丁
    Delphi中静态方法重载还是覆盖的讨论
    python全栈开发_day4_if,while和for
    python全栈开发_day3_数据类型,输入输出及运算符
  • 原文地址:https://www.cnblogs.com/yszd/p/9802435.html
Copyright © 2011-2022 走看看