zoukankan      html  css  js  c++  java
  • 从头开始使用梯度下降优化在Python中实现多元线性回归(代码)

    import matplotlib.pyplot as plt
    import numpy as np

    def hypothesis(theta, X, n):
    h = np.ones((X.shape[0],1))
    theta = theta.reshape(1,n+1)
    for i in range(0,X.shape[0]):
    h[i] = float(np.matmul(theta, X[i]))
    h = h.reshape(X.shape[0])
    return h


    def BGD(theta, alpha, num_iters, h, X, y, n):
    cost = np.ones(num_iters)
    for i in range(0,num_iters):
    theta[0] = theta[0] - (alpha/X.shape[0]) * sum(h - y)
    for j in range(1,n+1):
    theta[j] = theta[j] - (alpha/X.shape[0]) * sum((h-y) * X.transpose()[j])
    h = hypothesis(theta, X, n)
    cost[i] = (1/X.shape[0]) * 0.5 * sum(np.square(h - y))
    theta = theta.reshape(1,n+1)
    return theta, cost

    def linear_regression(X, y, alpha, num_iters):
    n = X.shape[1]
    one_column = np.ones((X.shape[0],1))
    X = np.concatenate((one_column, X), axis = 1)
    # initializing the parameter vector...
    theta = np.zeros(n+1)
    # hypothesis calculation....
    h = hypothesis(theta, X, n)
    # returning the optimized parameters by Gradient Descent...
    theta, cost = BGD(theta,alpha,num_iters,h,X,y,n)
    return theta, cost


    data = np.loadtxt('data1.txt', delimiter=',')
    X_train = data[:,:7] #feature set
    y_train = data[:,8] #label set


    mean = np.ones(X_train.shape[1])
    std = np.ones(X_train.shape[1])
    for i in range(0, X_train.shape[1]):
    mean[i] = np.mean(X_train.transpose()[i])
    std[i] = np.std(X_train.transpose()[i])
    for j in range(0,X_train.shape[0]):
    X_train[j][i] = (X_train[j][i] - mean[i])/std[i]


    theta, cost = linear_regression(X_train, y_train, 0.0001, 30000)


    cost = list(cost)
    n_iterations = [x for x in range(1,30001)]
    plt.plot(n_iterations, cost)
    plt.xlabel('No. of iterations')
    plt.ylabel('Cost')
  • 相关阅读:
    数据结构与算法(3-4)--矩阵的压缩存储
    数据结构与算法(3-3)--队列的应用
    数据结构与算法(3-2)--栈的应用
    数据结构与算法(3-1)--栈和队列
    数据结构与算法(2)--线性表(数组和链表)
    数据结构与算法(1)--时间及空间复杂度
    python变量与地址的关系
    python高级(03)--socket编程
    python高级(02)--生成器和迭代器
    python处理http接口请求
  • 原文地址:https://www.cnblogs.com/dr-xsh/p/13211927.html
Copyright © 2011-2022 走看看