zoukankan      html  css  js  c++  java
  • 逻辑回归

    from math import exp
    import numpy as np
    import pandas as pd
    from sklearn.datasets import load_iris
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split

    定义逻辑回归（LR）模型

    class LogisticReression:
        """Binary logistic-regression classifier trained with per-sample SGD.

        The class name keeps its original spelling so existing callers still work.

        Attributes:
            max_iter: Number of passes over the training data in ``train``.
            learing_rate: SGD step size (attribute name kept for compatibility).
            weights: (n_features + 1, 1) float32 column vector set by ``train``;
                ``weights[0]`` multiplies the constant 1.0 bias column.
        """

        def __init__(self, max_iter=200, learning_rate=0.01):
            # Was ``def init`` (dunder underscores lost), so these attributes
            # were never set and ``train`` failed with AttributeError.
            self.max_iter = max_iter
            self.learing_rate = learning_rate

        def sigmoid(self, x):
            """Return the logistic function 1 / (1 + e^(-x)), elementwise.

            Returns a float for scalar input and an ndarray otherwise.
            """
            x = np.asarray(x, dtype=np.float64)
            # exp(-|x|) <= 1, so neither branch can overflow; for x < 0 the
            # equivalent form e^x / (1 + e^x) is used.
            z = np.exp(-np.abs(x))
            out = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
            return float(out) if out.ndim == 0 else out

        def data_matrix(self, X):
            """Return X as a list of rows ``[1.0, x_1, ..., x_k]`` (bias column first)."""
            # ``*d`` flattens the row's features into the new row; ``[1.0, d]``
            # produced a ragged [scalar, array] pair instead.
            return [[1.0, *d] for d in X]

        # Training
        def train(self, X, y):
            """Fit ``self.weights`` by stochastic gradient ascent on the log-likelihood.

            Args:
                X: Iterable of feature rows.
                y: Sequence of 0/1 labels aligned with X.
            """
            data_mat = self.data_matrix(X)  # m x (n + 1), bias column first
            self.weights = np.zeros((len(data_mat[0]), 1), dtype=np.float32)
            for _ in range(self.max_iter):
                # Update the weights after every sample rather than once per full
                # pass, so each step stays cheap even for many samples.
                for row, label in zip(data_mat, y):
                    score = np.dot(row, self.weights).item()
                    error = label - self.sigmoid(score)
                    self.weights += self.learing_rate * error * np.transpose([row])
            print("LR模型学习率={},最大迭代次数={}".format(self.learing_rate, self.max_iter))

        # Accuracy
        def accuracy(self, X_test, y_test):
            """Return the fraction of test samples classified correctly.

            A sample counts as correct when w.x > 0 and its label is 1, or
            w.x < 0 and its label is 0; a score of exactly 0 counts as wrong.

            Raises:
                ValueError: if X_test is empty.
            """
            rows = self.data_matrix(X_test)
            if not rows:
                raise ValueError("X_test is empty")
            right = 0
            for x, y in zip(rows, y_test):
                result = np.dot(x, self.weights).item()
                if (result > 0 and y == 1) or (result < 0 and y == 0):
                    right += 1
            return right / len(rows)

    构建数据

    def create_data():
        """Build the two-class iris subset used by the tutorial.

        Returns:
            (features, labels): the first 100 iris rows restricted to sepal
            length and sepal width, and the matching label column.
        """
        iris = load_iris()
        frame = pd.DataFrame(iris.data, columns=iris.feature_names)
        frame['label'] = iris.target
        frame.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
        # Keep the first 100 rows, the two sepal columns and the label column.
        subset = np.array(frame.iloc[:100, [0, 1, -1]])
        print("data:", subset)
        features = subset[:, :2]
        labels = subset[:, -1]
        return features, labels
    # Load the two-class iris subset and hold out 30% of it for evaluation.
    # A fixed random_state makes the split (and everything computed from it)
    # reproducible from run to run.
    X, y = create_data()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    训练

    # Fit a model with the class defaults on the training split.
    LR = LogisticReression()
    LR.train(X=X_train, y=y_train)

    计算准确率

    # A bare expression's value is discarded when this runs as a script, so
    # print the held-out accuracy explicitly.
    print("accuracy:", LR.accuracy(X_test, y_test))

    效果展示

    # Decision boundary w0 + w1*x1 + w2*x2 = 0, solved for x2 over x1 in [3, 8].
    x_ponits = np.arange(3, 9)
    y_ = -(LR.weights[0] + LR.weights[1] * x_ponits) / LR.weights[2]
    plt.plot(x_ponits, y_)

    绘制样本散点图

    # Scatter the samples in two groups: the first 50 rows as '0', the rest as '1'.
    for _rows, _label in ((slice(None, 50), '0'), (slice(50, None), '1')):
        plt.scatter(X[_rows, 0], X[_rows, 1], label=_label)
    plt.legend()
    plt.show()

  • 相关阅读:
    第一周例行报告
    2018091-2 博客作业
    jQuery $.post $.ajax用法
    HTML ul、li 属性介绍
    PHP日期格式转时间戳
    php字符串与字符替换函数
    Linux内核参数
    ifconfig-dropped
    mysql_load_data及权限管理
    加快mysql导入导出速度
  • 原文地址:https://www.cnblogs.com/131415-520/p/12323435.html
Copyright © 2011-2022 走看看