  • Implementing multivariate linear regression with TensorFlow


    # -*- coding:utf-8 -*-
    # Written against the TensorFlow 1.x graph API; importing it through
    # tensorflow.compat.v1 lets the same code also run on TensorFlow 2.x.
    import tensorflow.compat.v1 as tf
    import numpy as np
    from sklearn import linear_model
    from sklearn import preprocessing
    tf.disable_v2_behavior()
    # Read x (one row per sample, one column per feature) and y
    x_data = np.loadtxt("ex3x.dat").astype(np.float32)
    y_data = np.loadtxt("ex3y.dat").astype(np.float32)
    # Fit with sklearn first to get a sense of the expected coefficients.
    reg = linear_model.LinearRegression()
    reg.fit(x_data, y_data)
    print("Coefficients of sklearn: K=%s, b=%f" % (reg.coef_, reg.intercept_))
    # Now use TensorFlow to get similar results.
    # Standardize x_data before feeding it to TensorFlow so that gradient
    # descent converges quickly; without standardization, convergence is
    # intolerably slow.
    # Reason: if one feature's variance is orders of magnitude larger than the
    # others', it dominates the objective function and prevents the estimator
    # from learning correctly from the other features.
    scaler = preprocessing.StandardScaler().fit(x_data)
    print(scaler.mean_, scaler.scale_)
    # Cast to float32 so the matmul dtype matches W
    x_data_standard = scaler.transform(x_data).astype(np.float32)
    n_features = x_data.shape[1]
    W = tf.Variable(tf.zeros([n_features, 1]))
    b = tf.Variable(tf.zeros([1, 1]))
    y = tf.matmul(x_data_standard, W) + b
    # Half mean squared error
    loss = tf.reduce_mean(tf.square(y - y_data.reshape(-1, 1))) / 2
    optimizer = tf.train.GradientDescentOptimizer(0.3)
    train = optimizer.minimize(loss)
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)  # variables must be initialized before they are read
    for step in range(100):
        sess.run(train)  # one full-batch gradient-descent step
        if step % 10 == 0:
            print(step, sess.run(W).flatten(), sess.run(b).flatten())
    W_val, b_val = sess.run(W), sess.run(b)
    print("Coefficients of tensorflow (input should be standardized): K=%s, b=%s"
          % (W_val.flatten(), b_val.flatten()))
    # Undo the standardization x_std = (x - mean) / scale to get raw-scale coefficients
    print("Coefficients of tensorflow (raw input): K=%s, b=%s"
          % (W_val.flatten() / scaler.scale_,
             b_val.flatten() - np.dot(scaler.mean_ / scaler.scale_, W_val)))
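
The comment about standardization can be made more concrete. For the half-mean-squared-error loss used above, the Hessian with respect to (W, b) is AᵀA/m, where A is the feature matrix with a column of ones appended for the bias, and full-batch gradient descent slows down as the ratio of its largest to smallest eigenvalue (the condition number) grows. The sketch below is only an illustration under stated assumptions: the two features (a hypothetical "size" around 2000 and "rooms" around 3) are synthetic and are not read from ex3x.dat, and `standardize` is a hand-written stand-in for what `StandardScaler` computes.

```python
import numpy as np


def hessian_condition_number(x):
    """Condition number of the Hessian of the half-MSE loss w.r.t. (W, b).

    For loss = mean((x @ W + b - y) ** 2) / 2 the Hessian is A.T @ A / m,
    where A is x with a column of ones appended for the bias term.
    """
    m = x.shape[0]
    a = np.hstack([x, np.ones((m, 1))])
    eigvals = np.linalg.eigvalsh(a.T @ a / m)  # ascending order, all >= 0
    return eigvals[-1] / eigvals[0]


def standardize(x):
    # Same transform as sklearn's StandardScaler: (x - mean) / std, with ddof=0
    return (x - x.mean(axis=0)) / x.std(axis=0)


# Hypothetical synthetic features on very different scales
rng = np.random.default_rng(0)
size = rng.normal(2000.0, 500.0, 200)
rooms = rng.normal(3.0, 0.8, 200)
x_raw = np.column_stack([size, rooms])

cond_raw = hessian_condition_number(x_raw)
cond_std = hessian_condition_number(standardize(x_raw))
print("condition number, raw features:          %.3g" % cond_raw)
print("condition number, standardized features: %.3g" % cond_std)
```

On features like these, the raw condition number comes out many orders of magnitude larger than the standardized one, which stays close to 1; that gap is one way to read the slow convergence the code comments warn about.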

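The final print statement maps the coefficients learned on standardized inputs back to the original feature scale. Writing μ_j and σ_j for the j-th entries of `scaler.mean_` and `scaler.scale_`, substituting the standardization into the model gives:

```latex
\tilde{x}_j = \frac{x_j - \mu_j}{\sigma_j}, \qquad
\hat{y} = \sum_j W_j \tilde{x}_j + b
        = \sum_j \frac{W_j}{\sigma_j}\, x_j
          + \Bigl(b - \sum_j \frac{\mu_j W_j}{\sigma_j}\Bigr)
```

so the raw-scale slopes are K_j = W_j / σ_j and the raw-scale intercept is b − Σ_j μ_j W_j / σ_j, which is what the last print statement computes from `scaler.scale_` and `scaler.mean_`.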

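To see what `optimizer.minimize(loss)` does on each step, the same computation can be written out by hand. The sketch below is a NumPy re-implementation under stated assumptions, not the author's code: the gradients are derived by hand for the half-MSE loss (∂L/∂W = Xᵀ(XW + b − y)/m, ∂L/∂b = mean(XW + b − y)), the learning rate 0.3 and the 100 steps are copied from the TensorFlow version, the data are synthetic (the "true" coefficients 150, -8000 and 90000 are made up), and `fit_linear_gd` and `to_raw_scale` are hypothetical helper names. It maps the result back to the raw scale and compares it with a closed-form least-squares fit, which plays the role sklearn's `LinearRegression` plays above.

```python
import numpy as np


def fit_linear_gd(x, y, learning_rate=0.3, steps=100):
    """Full-batch gradient descent on loss = mean((x @ w + b - y) ** 2) / 2.

    Mirrors the TensorFlow graph: zero-initialised parameters, a fixed
    learning rate and full-batch updates. Expects standardized x.
    """
    m, n = x.shape
    w = np.zeros(n)
    b = 0.0
    for _ in range(steps):
        residual = x @ w + b - y      # shape (m,)
        grad_w = x.T @ residual / m   # dL/dw
        grad_b = residual.mean()      # dL/db
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


def to_raw_scale(w, b, mean, scale):
    # K_j = w_j / sigma_j ; b_raw = b - sum_j mu_j * w_j / sigma_j
    return w / scale, b - np.dot(mean / scale, w)


# Hypothetical data: y = 150 * size - 8000 * rooms + 90000 + noise
rng = np.random.default_rng(1)
x = np.column_stack([rng.normal(2000.0, 500.0, 100), rng.normal(3.0, 0.8, 100)])
y = x @ np.array([150.0, -8000.0]) + 90000.0 + rng.normal(0.0, 1000.0, 100)

mean, scale = x.mean(axis=0), x.std(axis=0)  # what StandardScaler stores
w_std, b_std = fit_linear_gd((x - mean) / scale, y)
k_gd, b_gd = to_raw_scale(w_std, b_std, mean, scale)

# Closed-form least squares on the raw features, for comparison
a = np.column_stack([x, np.ones(len(x))])
coef, *_ = np.linalg.lstsq(a, y, rcond=None)
print("gradient descent: K=%s, b=%.2f" % (k_gd, b_gd))
print("least squares:    K=%s, b=%.2f" % (coef[:2], coef[2]))
```

With standardized features the loss is well conditioned, so 100 steps at learning rate 0.3 are enough for the two fits to agree closely. Running the same loop on the raw synthetic features at this learning rate would diverge, because 0.3 is far above 2/λ_max for the raw Hessian.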
  • Original article: https://www.cnblogs.com/hunttown/p/6844672.html