zoukankan      html  css  js  c++  java
  • 11.Lasso 岭回归的调用

    #!/usr/bin/python
    # -*- coding:utf-8 -*-
    
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import Lasso, Ridge
    from sklearn.model_selection import GridSearchCV
    
    
    if __name__ == "__main__":
        # pandas读入
        data = pd.read_csv(r'C:/8.Advertising.csv')    # TV、Radio、Newspaper、Sales
        x = data[['TV', 'Radio', 'Newspaper']]
        # x = data[['TV', 'Radio']]
        y = data['Sales']
        print(x)
        print(y)
    
        x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1)
        # print x_train, y_train
        #Lasso
        model = Lasso()
        #岭回归
        # model = Ridge()
    
        alpha_can = np.logspace(-3, 2, 10)
        lasso_model = GridSearchCV(model, param_grid={'alpha': alpha_can}, cv=5)
        lasso_model.fit(x, y)
        print('验证参数:
    ', lasso_model.best_params_)
    
        y_hat = lasso_model.predict(np.array(x_test))
        mse = np.average((y_hat - np.array(y_test)) ** 2)  # Mean Squared Error
        rmse = np.sqrt(mse)  # Root Mean Squared Error
        print(mse, rmse)
    
        t = np.arange(len(x_test))
        plt.plot(t, y_test, 'r-', linewidth=2, label='Test')
        plt.plot(t, y_hat, 'g-', linewidth=2, label='Predict')
        plt.legend(loc='upper right')
        plt.grid()
        plt.show()
  • 相关阅读:
    各地电信运营商插广告赚钱,北京联通也不甘落后
    也谈Server Limit DOS的解决方案
    Still Believe
    无奈小虫何
    好朋有也有类别
    无为而治
    青鸟随想
    落寞时分
    网站开发学习路线和资料
    C++实例 添加快捷键表
  • 原文地址:https://www.cnblogs.com/xiaochi/p/11264825.html
Copyright © 2011-2022 走看看