zoukankan      html  css  js  c++  java
  • 在Python中从零开始实现基于梯度下降优化的多元线性回归(代码)

    import matplotlib.pyplot as plt
    import numpy as np

    def hypothesis(theta, X, n):
        """Evaluate the linear hypothesis ``h = X @ theta`` for every row of X.

        Parameters
        ----------
        theta : array-like with n + 1 elements
            Parameter vector; any shape holding n + 1 values (e.g. a flat
            vector or a (1, n + 1) row) is accepted and flattened.
        X : 2-D array-like, shape (m, n + 1)
            Design matrix, one sample per row.
        n : int
            Number of features, excluding the bias column.

        Returns
        -------
        numpy.ndarray, shape (m,)
            Float predictions, one per row of X.
        """
        # A single matrix-vector product replaces the per-row Python loop,
        # which also relied on float() of a 1-element array (deprecated in
        # NumPy 1.25). dtype=float keeps the result float64 as before.
        theta = np.asarray(theta, dtype=float).reshape(n + 1)
        return np.asarray(X, dtype=float) @ theta


    def BGD(theta, alpha, num_iters, h, X, y, n):
        """Optimise theta with batch gradient descent on the squared-error cost.

        Each iteration performs one simultaneous update of all parameters,
        ``theta -= (alpha / m) * grad`` with ``grad = X.T @ (h - y)``. The
        bias component is always ``sum(h - y)``, which equals the full gradient
        when column 0 of X is all ones (as ``linear_regression`` arranges).
        After the update the predictions are refreshed and the cost
        ``(1 / m) * 0.5 * sum((h - y) ** 2)`` is recorded.

        Parameters
        ----------
        theta : array-like with n + 1 elements
            Starting parameters. The caller's array is not modified.
        alpha : float
            Learning rate.
        num_iters : int
            Number of gradient steps.
        h : array-like, shape (m,)
            Predictions for the starting theta; used for the first step.
        X : 2-D array-like, shape (m, n + 1)
            Design matrix, one sample per row.
        y : array-like, shape (m,)
            Target values.
        n : int
            Number of features, excluding the bias column.

        Returns
        -------
        theta : numpy.ndarray, shape (1, n + 1)
            Optimised parameters as a row vector.
        cost : numpy.ndarray, shape (num_iters,)
            Cost after each iteration.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        # Update a float copy: writing into the caller's array was a hidden
        # side effect and silently truncated every update for an int theta.
        theta = np.array(theta, dtype=float).reshape(n + 1)
        residual = np.asarray(h, dtype=float) - y
        m = X.shape[0]
        cost = np.empty(num_iters)
        for i in range(num_iters):
            # Every component uses the same residual -> simultaneous update.
            grad = X.T @ residual
            grad[0] = residual.sum()
            theta = theta - (alpha / m) * grad
            # Refresh predictions (same as hypothesis(theta, X, n)).
            residual = X @ theta - y
            cost[i] = (1 / m) * 0.5 * np.sum(np.square(residual))
        return theta.reshape(1, n + 1), cost

    def linear_regression(X, y, alpha, num_iters):
        """Fit a multivariate linear model with batch gradient descent.

        A column of ones (the bias feature) is prepended to X, every
        parameter starts at zero, and ``BGD`` performs the optimisation.

        Parameters
        ----------
        X : 2-D array, shape (m, n)
            Feature matrix without a bias column.
        y : array-like, shape (m,)
            Target values.
        alpha : float
            Learning rate passed to ``BGD``.
        num_iters : int
            Number of gradient steps passed to ``BGD``.

        Returns
        -------
        tuple
            ``(theta, cost)`` exactly as produced by ``BGD``.
        """
        n_features = X.shape[1]
        bias = np.ones((X.shape[0], 1))
        design = np.concatenate([bias, X], axis=1)
        theta_init = np.zeros(n_features + 1)
        # Predictions for the all-zero starting parameters.
        h_init = hypothesis(theta_init, design, n_features)
        theta, cost = BGD(theta_init, alpha, num_iters, h_init, design, y, n_features)
        return theta, cost


    # Read the comma-separated dataset; columns 0-6 are used as features and
    # column 8 as the regression target.
    # NOTE(review): column 7 is loaded but used neither as a feature nor as the
    # label - confirm this is intentional and not an off-by-one (e.g. the label
    # should be data[:, 7], or the features data[:, :8]).
    data = np.loadtxt('data1.txt', delimiter=',')
    X_train, y_train = data[:, 0:7], data[:, 8]


    # Standardise every feature column to zero mean and unit (population)
    # variance so all features share a comparable scale for gradient descent.
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    # A constant column has std == 0; dividing by it would fill the column with
    # NaN and poison every gradient step, so such columns are only centred.
    # The stored std is patched too, so it matches the scaling actually applied.
    std[std == 0] = 1.0
    X_train = (X_train - mean) / std


    # Gradient-descent hyperparameters.
    ALPHA = 0.0001
    NUM_ITERS = 30000
    theta, cost = linear_regression(X_train, y_train, ALPHA, NUM_ITERS)

    # Learning curve: cost after each iteration. The x-axis is derived from the
    # cost array itself, so it cannot drift out of sync with NUM_ITERS.
    cost = list(cost)
    n_iterations = list(range(1, len(cost) + 1))
    plt.plot(n_iterations, cost)
    plt.xlabel('No. of iterations')
    plt.ylabel('Cost')
    # Outside an interactive session the figure is only displayed by show().
    plt.show()
  • 相关阅读:
    回调函数(C语言)
    main函数的参数(一)
    术语,概念
    [LeetCode] Invert Binary Tree
    关于overload和override
    第一个只出现一次的字符
    Manacher算法----最长回文子串
    C++对象模型
    回文判断
    字符串转换成整数
  • 原文地址:https://www.cnblogs.com/dr-xsh/p/13211927.html
Copyright © 2011-2022 走看看