zoukankan      html  css  js  c++  java
  • 梯度下降法(BGD)

    # coding: utf-8
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    
    
    class LR:
        """Linear regression trained with (full-)batch gradient descent.

        Expects ``data`` as rows of ``[feature_1, ..., feature_k, target]``;
        features are standardized and a bias column of ones is prepended.
        """

        def __init__(self, data, learning_rate=0.001, iter_max=10000, batch_size=2):
            # batch_size is kept for interface compatibility; this
            # implementation always uses the full batch (true BGD).
            self.data = data
            self.learning_rate = learning_rate
            self.iter_max = iter_max
            self.batch_size = batch_size
            self.process_data()

        def standard_scaler(self, data):
            """Standardize feature columns to zero mean / unit variance.

            The last column (target) is passed through unchanged.
            """
            features = data[:, :-1]
            mean = np.mean(features, axis=0)
            std = np.std(features, axis=0)
            std[std == 0] = 1.0  # constant column: avoid division by zero
            features = (features - mean) / std
            return np.hstack((features, data[:, -1:]))

        def process_data(self):
            """Standardize features and prepend a bias column of ones."""
            data = np.array(self.data, dtype=float)
            data = self.standard_scaler(data)
            one = np.ones((data.shape[0], 1))
            self.data = np.hstack((one, data))
            self.m = self.data.shape[0]      # total number of samples
            self.n = self.data.shape[1] - 1  # number of features (incl. bias)

        def model(self):
            """Return predictions X @ theta, shape (m, 1)."""
            return np.dot(self.data[:, :-1], self.theta)

        def mse(self, predict, y):
            """Mean squared error between predictions and targets."""
            return np.sum((predict - y) ** 2) / len(y)

        def cal_grad(self, predict, y):
            """Return the MSE gradient (up to a constant factor): X.T @ (pred - y) / m.

            BUGFIX: the previous per-feature loop did
            ``np.mean((predict - y) * self.data[:, i])`` where ``predict - y``
            is (m, 1) and the column is (m,); broadcasting produced an (m, m)
            outer product, so the "gradient" was mean(error) * mean(feature)
            instead of the per-sample mean of error * feature.  The vectorized
            form below computes the correct value, shape (n, 1).
            """
            return np.dot(self.data[:, :-1].T, predict - y) / self.m

        @staticmethod
        def draw(list_data):
            """Plot the recorded loss values against iteration index."""
            plt.plot(range(len(list_data)), list_data)
            plt.show()

        def train(self):
            """Run batch gradient descent, then plot the loss curve.

            Performs iter_max + 1 parameter updates (matching the original
            counter-based loop, which broke only after the update on
            iteration iter_max + 1).
            """
            loss_list = []
            # 1. initialize theta
            self.theta = np.ones((self.n, 1))
            predict = self.model()
            y = self.data[:, -1:]
            # 2. record initial loss
            loss_list.append(self.mse(predict, y))
            for _ in range(self.iter_max + 1):
                # 3. gradient of the current predictions
                grad = self.cal_grad(predict, y)
                # 4. update parameters
                self.theta = self.theta - self.learning_rate * grad
                # 5. recompute predictions and loss
                predict = self.model()
                loss_list.append(self.mse(predict, y))
            self.draw(loss_list)
    
    
    if __name__ == "__main__":
        # Load training data (feature columns + target in the last column)
        # from a local Excel file; the path is machine-specific.
        data = pd.read_excel('C:/Users/jiedada/Desktop/python/回归/lr.xlsx')
        lr = LR(data)
        lr.train()
  • 相关阅读:
    mysql函数基本使用
    django form 组件源码解析
    jwt
    python数据类型 ——bytes 和 bytearray
    汇编基础四 --函数调用与堆栈平衡
    汇编基础之三 -- 汇编指令
    汇编基础之二 -- 寄存器和内存堆栈
    汇编基础之一 -- 位运算和四则运算的实现
    存储过程中的设置语句含义
    (转载)SQL去除回车符,换行符,空格和水平制表符
  • 原文地址:https://www.cnblogs.com/xiaoruirui/p/15736299.html
Copyright © 2011-2022 走看看