  • Multiple Linear Regression with TensorFlow

    Multiple linear regression implemented with TensorFlow. The code is as follows:

    # -*- coding:utf-8 -*-
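    # The script uses the TensorFlow 1.x graph/session API; importing it through
    # the compat.v1 module lets it run under TensorFlow 2.x as well.
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()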
    import numpy as np
    from sklearn import linear_model
    from sklearn import preprocessing
    
    # Read x and y
    x_data = np.loadtxt("ex3x.dat").astype(np.float32)
    y_data = np.loadtxt("ex3y.dat").astype(np.float32)
    
    # Fit the same data with sklearn first to get reference values for the coefficients.
    reg = linear_model.LinearRegression()
    reg.fit(x_data, y_data)
    print("Coefficients of sklearn: K=%s, b=%f" % (reg.coef_, reg.intercept_))
    
    # Now use TensorFlow to obtain similar results.
    # Before feeding x_data to TensorFlow, standardize it (zero mean, unit variance)
    # so that gradient descent performs well; without standardization,
    # convergence is unacceptably slow.
    # Reason: if one feature's variance is orders of magnitude larger than the others',
    # it dominates the objective function and prevents the estimator
    # from learning correctly from the other features.
    scaler = preprocessing.StandardScaler().fit(x_data)
    print("Feature means:", scaler.mean_, "feature scales:", scaler.scale_)
    # Cast to float32 so the matrix product matches the dtype of the TF variables.
    x_data_standard = scaler.transform(x_data).astype(np.float32)
    
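    # Linear model on the standardized features: y = x_std * W + b (2 features, 1 output)
    W = tf.Variable(tf.zeros([2, 1]))
    b = tf.Variable(tf.zeros([1, 1]))
    y = tf.matmul(x_data_standard, W) + b
    
    # Loss: half the mean squared error, minimized by plain gradient descent (learning rate 0.3)
    loss = tf.reduce_mean(tf.square(y - y_data.reshape(-1, 1))) / 2
    optimizer = tf.train.GradientDescentOptimizer(0.3)
    train = optimizer.minimize(loss)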
    
    # initialize_all_variables() is deprecated; global_variables_initializer() replaces it
    init = tf.global_variables_initializer()
    
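    sess = tf.Session()
    sess.run(init)
    for step in range(100):
        sess.run(train)
        if step % 10 == 0:
            print(step, sess.run(W).flatten(), sess.run(b).flatten())
    
    # Fetch the learned parameters once instead of re-running the graph for every print.
    W_val, b_val = sess.run([W, b])
    sess.close()
    print("Coefficients of tensorflow (standardized input): K=%s, b=%s"
          % (W_val.flatten(), b_val.flatten()))
    # Undo the standardization to express the coefficients in terms of the raw features:
    # K_raw = K / scale_, b_raw = b - sum(mean_ / scale_ * K)
    print("Coefficients of tensorflow (raw input): K=%s, b=%s"
          % (W_val.flatten() / scaler.scale_,
             b_val.flatten() - np.dot(scaler.mean_ / scaler.scale_, W_val)))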
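    The raw-input coefficients printed at the end follow from undoing the standardization. Writing μ_j and σ_j for the mean and scale of feature j (scaler.mean_ and scaler.scale_) and w_j, b for the parameters learned on the standardized inputs, the prediction can be regrouped in terms of the raw features:

```latex
\hat{y} = \sum_{j} w_j \,\frac{x_j - \mu_j}{\sigma_j} + b
        = \sum_{j} \frac{w_j}{\sigma_j}\, x_j
          + \Bigl(b - \sum_{j} \frac{\mu_j}{\sigma_j}\, w_j\Bigr)
```

    So the raw-input slope for feature j is w_j / σ_j and the raw-input intercept is b - Σ_j (μ_j / σ_j) w_j, which is what the last print statement computes.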

    Dataset download: download link
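    If the ex3x.dat / ex3y.dat files are not at hand, the procedure can still be checked without TensorFlow. The sketch below is a plain-NumPy re-implementation (not the author's code) of the script's steps: standardize the features, run 100 steps of batch gradient descent with learning rate 0.3 on half the mean squared error, then convert the coefficients back to the raw scale. The data are synthetic and hypothetical (two features on very different scales, chosen only for illustration), and the result is compared with the closed-form least-squares fit from np.linalg.lstsq.

```python
import numpy as np

# Hypothetical synthetic data standing in for ex3x.dat / ex3y.dat:
# two features on very different scales (values chosen only for illustration).
rng = np.random.default_rng(0)
m = 50
x = np.column_stack([rng.uniform(800, 4500, m), rng.integers(1, 6, m)])
true_k = np.array([140.0, -8000.0])
true_b = 90000.0
y = x @ true_k + true_b + rng.normal(0, 1000, m)

# Standardize each feature to zero mean and unit variance, as StandardScaler does
# (StandardScaler's scale_ is the population standard deviation, ddof=0).
mean = x.mean(axis=0)
scale = x.std(axis=0)
x_std = (x - mean) / scale

# Batch gradient descent on loss = mean((x_std @ w + b - y) ** 2) / 2
w = np.zeros(2)
b = 0.0
lr = 0.3
for step in range(100):
    err = x_std @ w + b - y          # residuals, shape (m,)
    grad_w = x_std.T @ err / m        # d loss / d w
    grad_b = err.mean()               # d loss / d b
    w -= lr * grad_w
    b -= lr * grad_b

# Convert back to coefficients on the raw features.
k_raw = w / scale
b_raw = b - np.dot(mean / scale, w)

# Reference solution: ordinary least squares on [x, 1].
A = np.column_stack([x, np.ones(m)])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
print("gradient descent:", k_raw, b_raw)
print("least squares:   ", coef[:2], coef[2])
```

    With standardized inputs the two results agree closely after 100 steps. Running the same loop on the raw x instead diverges at learning rate 0.3, because the gradient along the large-scale feature is orders of magnitude larger, and a learning rate small enough to be stable makes progress on the small-scale feature extremely slowly.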

  • Original article: https://www.cnblogs.com/hunttown/p/6844672.html