zoukankan      html  css  js  c++  java
  • 11. Lasso 回归与岭回归的调用

    #!/usr/bin/python
    # -*- coding:utf-8 -*-
    
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import Lasso, Ridge
    from sklearn.model_selection import GridSearchCV
    
    
    def regression_errors(y_true, y_pred):
        """Return the mean squared error and root mean squared error of a prediction.

        Both inputs are converted to plain float arrays first, so a pandas
        Series (e.g. a shuffled ``y_test`` from ``train_test_split``) is
        compared by position rather than aligned on its index.

        Args:
            y_true: Observed target values (array-like).
            y_pred: Predicted values (array-like), same shape as ``y_true``.

        Returns:
            tuple[float, float]: ``(mse, rmse)``.

        Raises:
            ValueError: If the two inputs differ in shape or are empty.
        """
        actual = np.asarray(y_true, dtype=float)
        predicted = np.asarray(y_pred, dtype=float)
        # Guard against silent broadcasting (e.g. a length-1 array against a
        # full vector) producing a meaningless error value.
        if actual.shape != predicted.shape:
            raise ValueError(
                f"shape mismatch: y_true {actual.shape} vs y_pred {predicted.shape}"
            )
        if actual.size == 0:
            raise ValueError("cannot compute errors for empty inputs")
        mse = float(np.mean((predicted - actual) ** 2))
        return mse, float(np.sqrt(mse))


    if __name__ == "__main__":
        # Load the data with pandas.
        # Columns used: TV, Radio, Newspaper (features) and Sales (target).
        data = pd.read_csv(r'C:/8.Advertising.csv')
        x = data[['TV', 'Radio', 'Newspaper']]
        # x = data[['TV', 'Radio']]  # alternative feature set without Newspaper
        y = data['Sales']
        print(x)
        print(y)

        # Hold out a test split; random_state fixes the shuffle for reproducibility.
        x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1)

        # Lasso regression; swap in the commented line to use ridge regression.
        model = Lasso()
        # model = Ridge()

        # Choose alpha from 10 log-spaced values in [1e-3, 1e2] by 5-fold
        # cross-validation. Fit on the training split only, so the test split
        # is never seen during alpha selection or the final refit.
        alpha_can = np.logspace(-3, 2, 10)
        search = GridSearchCV(model, param_grid={'alpha': alpha_can}, cv=5)
        search.fit(x_train, y_train)
        print('验证参数:\n', search.best_params_)

        # Predict on the DataFrame itself (not a bare ndarray) so the feature
        # names match those seen during fit.
        y_hat = search.predict(x_test)
        mse, rmse = regression_errors(y_test, y_hat)
        print(mse, rmse)

        # Plot observed vs. predicted test targets by sample position.
        t = np.arange(len(x_test))
        plt.plot(t, np.asarray(y_test), 'r-', linewidth=2, label='Test')
        plt.plot(t, y_hat, 'g-', linewidth=2, label='Predict')
        plt.legend(loc='upper right')
        plt.grid()
        plt.show()
  • 相关阅读:
    es进行聚合操作时提示Fielddata is disabled on text fields by default
    es基本操作
    maven项目修改项目名
    Linux命令整理
    CentOS 安装git
    Linux命令
    纵表转横表
    Row_Number() over()
    事件冒泡/捕获
    js获取参数 解决乱码
  • 原文地址:https://www.cnblogs.com/xiaochi/p/11264825.html
Copyright © 2011-2022 走看看