# -*- coding: utf-8 -*-
"""
Created on Mon Oct 15 17:38:39 2018

@author: zhen

Fit a linear regression to the California housing data set with full-batch
gradient descent, using the TensorFlow 1.x graph/session API
(tf.Session, tf.train.GradientDescentOptimizer).
"""

import tensorflow as tf
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler

n_epochs = 10000
learning_rate = 0.01

# NOTE: data_home is a machine-specific cache directory; adjust it locally.
housing = fetch_california_housing(data_home="C:/Users/zhen/.spyder-py3/data", download_if_missing=True)
m, n = housing.data.shape

# Normalize the features FIRST, then prepend the bias column.
# Standardizing a matrix that already contains the all-ones column maps that
# column to all zeros (mean 1, variance 0), which silently removes the
# intercept: the gradient w.r.t. theta[0] would always be 0.
scaler = StandardScaler().fit(housing.data)
scaled_housing_data = scaler.transform(housing.data)
scaled_housing_data_plus_bias = np.c_[np.ones((m, 1)), scaled_housing_data]

# Constant tensors: the (m, n + 1) design matrix and the (m, 1) target column.
x = tf.constant(scaled_housing_data_plus_bias, dtype=tf.float32, name='x')
y = tf.constant(housing.target.reshape(-1, 1), dtype=tf.float32, name='y')
# Model parameters, initialized uniformly at random between -1 and 1.
theta = tf.Variable(tf.random_uniform([n + 1, 1], -1.0, 1.0), name='theta')
# Linear model: matrix product x @ theta.
y_pred = tf.matmul(x, theta, name="predictions")

error = y_pred - y
# Mean squared error over all m samples.
mse = tf.reduce_mean(tf.square(error), name="mse")

# Equivalent manual update, kept for reference:
#     gradients = tf.gradients(mse, [theta])[0]
#     training_op = tf.assign(theta, theta - learning_rate * gradients)
# Here the built-in optimizer computes the gradient and applies the step.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(mse)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    for epoch in range(n_epochs):
        if epoch % 100 == 0:
            # Report the current loss every 100 epochs (before this step's update).
            print("Epoch", epoch, "MSE = ", mse.eval())
        sess.run(training_op)

    # theta.eval() needs the active session, so read it inside the block.
    best_theta = theta.eval()

print(best_theta)
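结果（注：以下输出由原脚本得到。原脚本先在数据前拼接全 1 的偏置列再做标准化，StandardScaler 会把这一常数列变成全 0，于是截距项无法学习——theta[0] 的梯度恒为 0，始终停留在随机初始值，这也是下面 MSE 一直停在约 4.80 的主要原因。正确做法是先标准化特征，再拼接偏置列。）: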
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
Epoch 0 MSE = 9.128207 Epoch 100 MSE = 4.893214 Epoch 200 MSE = 4.8329406 Epoch 300 MSE = 4.824335 Epoch 400 MSE = 4.8187895 Epoch 500 MSE = 4.814753 Epoch 600 MSE = 4.811796 Epoch 700 MSE = 4.8096204 Epoch 800 MSE = 4.808017 Epoch 900 MSE = 4.806835 Epoch 1000 MSE = 4.805955 Epoch 1100 MSE = 4.805301 Epoch 1200 MSE = 4.8048124 Epoch 1300 MSE = 4.804449 Epoch 1400 MSE = 4.804172 Epoch 1500 MSE = 4.803962 Epoch 1600 MSE = 4.8038034 Epoch 1700 MSE = 4.803686 Epoch 1800 MSE = 4.8035927 Epoch 1900 MSE = 4.80352 Epoch 2000 MSE = 4.8034678 Epoch 2100 MSE = 4.803425 Epoch 2200 MSE = 4.8033857 Epoch 2300 MSE = 4.803362 Epoch 2400 MSE = 4.803341 Epoch 2500 MSE = 4.8033247 Epoch 2600 MSE = 4.80331 Epoch 2700 MSE = 4.8033013 Epoch 2800 MSE = 4.8032923 Epoch 2900 MSE = 4.8032856 Epoch 3000 MSE = 4.8032804 Epoch 3100 MSE = 4.803273 Epoch 3200 MSE = 4.803271 Epoch 3300 MSE = 4.8032694 Epoch 3400 MSE = 4.803267 Epoch 3500 MSE = 4.8032637 Epoch 3600 MSE = 4.8032603 Epoch 3700 MSE = 4.803259 Epoch 3800 MSE = 4.803259 Epoch 3900 MSE = 4.8032584 Epoch 4000 MSE = 4.8032575 Epoch 4100 MSE = 4.8032575 Epoch 4200 MSE = 4.803256 Epoch 4300 MSE = 4.803255 Epoch 4400 MSE = 4.803256 Epoch 4500 MSE = 4.803256 Epoch 4600 MSE = 4.803253 Epoch 4700 MSE = 4.8032565 Epoch 4800 MSE = 4.803258 Epoch 4900 MSE = 4.8032556 Epoch 5000 MSE = 4.803256 Epoch 5100 MSE = 4.8032537 Epoch 5200 MSE = 4.8032565 Epoch 5300 MSE = 4.803255 Epoch 5400 MSE = 4.8032546 Epoch 5500 MSE = 4.803254 Epoch 5600 MSE = 4.8032537 Epoch 5700 MSE = 4.8032517 Epoch 5800 MSE = 4.8032527 Epoch 5900 MSE = 4.8032537 Epoch 6000 MSE = 4.803254 Epoch 6100 MSE = 4.8032546 Epoch 6200 MSE = 4.803255 Epoch 6300 MSE = 4.8032546 Epoch 6400 MSE = 4.803253 Epoch 6500 MSE = 4.803253 Epoch 6600 MSE = 4.803253 Epoch 6700 MSE = 4.8032517 Epoch 6800 MSE = 4.803252 Epoch 6900 MSE = 4.8032517 Epoch 7000 MSE = 4.803252 Epoch 7100 MSE = 4.8032537 Epoch 7200 MSE = 4.8032537 Epoch 7300 MSE = 4.803253 Epoch 7400 MSE = 4.803253 Epoch 7500 MSE = 4.803253 
Epoch 7600 MSE = 4.803254 Epoch 7700 MSE = 4.8032546 Epoch 7800 MSE = 4.8032556 Epoch 7900 MSE = 4.803256 Epoch 8000 MSE = 4.8032565 Epoch 8100 MSE = 4.8032565 Epoch 8200 MSE = 4.8032565 Epoch 8300 MSE = 4.8032556 Epoch 8400 MSE = 4.8032565 Epoch 8500 MSE = 4.8032575 Epoch 8600 MSE = 4.8032565 Epoch 8700 MSE = 4.803256 Epoch 8800 MSE = 4.803256 Epoch 8900 MSE = 4.8032556 Epoch 9000 MSE = 4.803255 Epoch 9100 MSE = 4.8032546 Epoch 9200 MSE = 4.803254 Epoch 9300 MSE = 4.8032546 Epoch 9400 MSE = 4.8032546 Epoch 9500 MSE = 4.803255 Epoch 9600 MSE = 4.803255 Epoch 9700 MSE = 4.803255 Epoch 9800 MSE = 4.803255 Epoch 9900 MSE = 4.803255 [[ 0.43350863] [ 0.8296331 ] [ 0.11875448] [-0.26555073] [ 0.3057157 ] [-0.00450223] [-0.03932685] [-0.8998542 ] [-0.87051094]]
结果样例: