  • ML.NET Tutorial: Taxi Fare Prediction (a Regression Problem)

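    Understanding the Problem

    A taxi fare depends not only on distance but also on factors such as the number of passengers and whether the fare is paid by credit card (the taxis here are New York City taxis). So this is not a simple single-variable equation.

    Preparing the Data

    Create a console application project and add a Data folder containing the files taxi-fare-train.csv and taxi-fare-test.csv. Remember to set each file's Copy to Output Directory property to Copy if newer. Then add the Microsoft.ML NuGet package.

    Loading the Data

    Create an MLContext object and a TextLoader object. The TextLoader is used to read data from a file.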

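    // A fixed seed keeps results repeatable across runs.
    MLContext mlContext = new MLContext(seed: 0);
    
    // Describe the file layout: comma-separated, header row, seven typed columns.
    _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()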
    {
        Separator = ",",
        HasHeader = true,
        Column = new[]
        {
            new TextLoader.Column("VendorId", DataKind.Text, 0),
            new TextLoader.Column("RateCode", DataKind.Text, 1),
            new TextLoader.Column("PassengerCount", DataKind.R4, 2),
            new TextLoader.Column("TripTime", DataKind.R4, 3),
            new TextLoader.Column("TripDistance", DataKind.R4, 4),
            new TextLoader.Column("PaymentType", DataKind.Text, 5),
            new TextLoader.Column("FareAmount", DataKind.R4, 6)
        }
    });
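
    To make the column mapping concrete, here is a minimal sketch in plain C# (no ML.NET dependency) of what reading one row amounts to: split on the separator and convert each field to its declared kind. RowParser and ParseRow are hypothetical names used only for illustration, and the sample row reuses the values from the prediction sample later in this post. The real TextLoader also deals with quoting, missing values, and lazy loading, which this sketch ignores.

    ```csharp
    using System;
    using System.Globalization;

    public static class RowParser
    {
        // Splits one comma-separated row into the seven typed columns declared above.
        public static (string VendorId, string RateCode, float PassengerCount,
                       float TripTime, float TripDistance, string PaymentType,
                       float FareAmount) ParseRow(string line)
        {
            string[] f = line.Split(',');
            if (f.Length != 7)
                throw new FormatException($"Expected 7 columns, got {f.Length}.");

            // Text columns stay strings; R4 columns become 32-bit floats.
            float ToFloat(string s) => float.Parse(s, CultureInfo.InvariantCulture);

            return (f[0], f[1], ToFloat(f[2]), ToFloat(f[3]), ToFloat(f[4]), f[5], ToFloat(f[6]));
        }

        public static void Main()
        {
            var row = ParseRow("VTS,1,1,1140,3.75,CRD,15.5");
            Console.WriteLine($"{row.VendorId} {row.TripDistance} {row.FareAmount}");
        }
    }
    ```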
    

    Extracting Features

    The dataset files have seven columns: the first six serve as features and the last one is the label.

    // One input row; each [Column("n")] gives the zero-based column index in the file.
    public class TaxiTrip
    {
        [Column("0")]
        public string VendorId;
    
        [Column("1")]
        public string RateCode;
    
        [Column("2")]
        public float PassengerCount;
    
        [Column("3")]
        public float TripTime;
    
        [Column("4")]
        public float TripDistance;
    
        [Column("5")]
        public string PaymentType;
    
        [Column("6")]
        public float FareAmount;
    }
    
    // Prediction output; the predicted fare is read from the Score column.
    public class TaxiTripFarePrediction
    {
        [ColumnName("Score")]
        public float FareAmount;
    }
    

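    Training the Model

    First read the training dataset, then build the pipeline. The first step copies the FareAmount column to the Label column, which serves as the label. The second step uses OneHotEncoding to convert the three string columns VendorId, RateCode, and PaymentType into numeric columns. The third step concatenates the six feature columns into a single Features column. The last step selects FastTreeRegressionTrainer as the training algorithm.
    Once the pipeline is complete, train the model.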

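    IDataView dataView = _textLoader.Read(dataPath);
    // Step 1: copy FareAmount to Label, the column the trainer learns to predict.
    var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
        // Step 2: turn the three string columns into numeric one-hot vectors.
        .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
        .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
        .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
        // Step 3: merge the six input columns into a single Features column.
        .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
        // Step 4: train with the FastTree regression algorithm.
        .Append(mlContext.Regression.Trainers.FastTree());
    // Fit runs every step on the training data and returns the trained model.
    var model = pipeline.Fit(dataView);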
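
    The one-hot step can be pictured with a small plain C# sketch that does not use ML.NET. Each distinct string gets a slot in a vector, and a value is encoded as 1 in its own slot and 0 elsewhere. OneHotSketch, BuildVocabulary, and Encode are hypothetical helpers, and the vendor code CMT in the sample is an assumption. ML.NET's own transform builds its vocabulary from the training data, and its slot ordering and handling of unseen values may differ from this sketch.

    ```csharp
    using System;
    using System.Collections.Generic;

    public static class OneHotSketch
    {
        // Assigns each distinct value an index, in order of first appearance.
        public static Dictionary<string, int> BuildVocabulary(IEnumerable<string> values)
        {
            var vocab = new Dictionary<string, int>();
            foreach (string v in values)
            {
                if (!vocab.ContainsKey(v))
                    vocab.Add(v, vocab.Count);
            }
            return vocab;
        }

        // Returns a vector with 1 in the slot for the value; unseen values give all zeros.
        public static float[] Encode(Dictionary<string, int> vocab, string value)
        {
            var vector = new float[vocab.Count];
            if (vocab.TryGetValue(value, out int index))
                vector[index] = 1f;
            return vector;
        }

        public static void Main()
        {
            // Hypothetical sample values; "VTS" appears in this post, "CMT" is assumed.
            var vocab = BuildVocabulary(new[] { "VTS", "CMT", "VTS" });
            Console.WriteLine(string.Join(",", Encode(vocab, "CMT"))); // prints 0,1
        }
    }
    ```

    Encoding strings this way lets the Concatenate step place them alongside the numeric columns in one Features vector.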
    

    Evaluating the Model

    Evaluation uses the test dataset and the regression Evaluate method.

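    IDataView dataView = _textLoader.Read(_testDataPath);
    // Run the trained model over the test rows to produce a Score for each one.
    var predictions = model.Transform(dataView);
    // Compare the Score column with the Label column to compute regression metrics.
    var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");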
    Console.WriteLine();
    Console.WriteLine($"*************************************************");
    Console.WriteLine($"*       Model quality metrics evaluation         ");
    Console.WriteLine($"*------------------------------------------------");
    Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
    Console.WriteLine($"*       RMS loss:      {metrics.Rms:#.##}");
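
    For reference, the two printed metrics can be computed by hand. The sketch below uses the usual definitions: R2 is 1 minus the ratio of the residual sum of squares to the total sum of squares, and RMS is the square root of the mean squared error. It is plain C# with hypothetical names (RegressionMetricsSketch, RSquared, Rms) and made-up sample numbers. It is meant to show what the values mean, not to reproduce ML.NET's evaluator exactly.

    ```csharp
    using System;
    using System.Linq;

    public static class RegressionMetricsSketch
    {
        // Square root of the mean squared difference between labels and predictions.
        public static double Rms(double[] labels, double[] scores)
        {
            double sumSq = labels.Zip(scores, (y, p) => (y - p) * (y - p)).Sum();
            return Math.Sqrt(sumSq / labels.Length);
        }

        // 1 - (residual sum of squares / total sum of squares around the label mean).
        public static double RSquared(double[] labels, double[] scores)
        {
            double mean = labels.Average();
            double ssRes = labels.Zip(scores, (y, p) => (y - p) * (y - p)).Sum();
            double ssTot = labels.Sum(y => (y - mean) * (y - mean));
            return 1.0 - ssRes / ssTot;
        }

        public static void Main()
        {
            // Illustrative values only, not taken from the taxi dataset.
            double[] labels = { 10.0, 20.0, 30.0 };
            double[] scores = { 12.0, 18.0, 30.0 };
            Console.WriteLine($"R2:  {RSquared(labels, scores):0.##}");
            Console.WriteLine($"RMS: {Rms(labels, scores):0.##}");
        }
    }
    ```

    An R2 close to 1 and a small RMS both indicate that the predictions track the observed fares closely.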
    

    Saving the Model

    A trained model can be saved as a zip file for later use.

    using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
        mlContext.Model.Save(model, fileStream);
    

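    Using the Model

    First load the saved model. Then create a prediction function object, with TaxiTrip as its input type and TaxiTripFarePrediction as its output type. Finally, call its Predict method with the data to be predicted.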

    ITransformer loadedModel;
    using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
    {
        loadedModel = mlContext.Model.Load(stream);
    }
    
    var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);
    
    var taxiTripSample = new TaxiTrip()
    {
        VendorId = "VTS",
        RateCode = "1",
        PassengerCount = 1,
        TripTime = 1140,
        TripDistance = 3.75f,
        PaymentType = "CRD",
        FareAmount = 0 // To predict. Actual/Observed = 15.5
    };
    
    var prediction = predictionFunction.Predict(taxiTripSample);
    
    Console.WriteLine($"**********************************************************************");
    Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
    Console.WriteLine($"**********************************************************************");
    

    Complete Sample Code

    using Microsoft.ML;
    using Microsoft.ML.Core.Data;
    using Microsoft.ML.Runtime.Data;
    using System;
    using System.IO;
    
    namespace TaxiFarePredictor
    {
        class Program
        {
            static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-train.csv");
            static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-test.csv");
            static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "Model.zip");
            static TextLoader _textLoader;
    
            static void Main(string[] args)
            {
                MLContext mlContext = new MLContext(seed: 0);
    
                _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
                {
                    Separator = ",",
                    HasHeader = true,
                    Column = new[]
                    {
                        new TextLoader.Column("VendorId", DataKind.Text, 0),
                        new TextLoader.Column("RateCode", DataKind.Text, 1),
                        new TextLoader.Column("PassengerCount", DataKind.R4, 2),
                        new TextLoader.Column("TripTime", DataKind.R4, 3),
                        new TextLoader.Column("TripDistance", DataKind.R4, 4),
                        new TextLoader.Column("PaymentType", DataKind.Text, 5),
                        new TextLoader.Column("FareAmount", DataKind.R4, 6)
                    }
                });
    
                var model = Train(mlContext, _trainDataPath);
    
                Evaluate(mlContext, model);
    
                TestSinglePrediction(mlContext);
    
                Console.Read();
            }
    
            public static ITransformer Train(MLContext mlContext, string dataPath)
            {
                IDataView dataView = _textLoader.Read(dataPath);
                var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
                    .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
                    .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
                    .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
                    .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
                    .Append(mlContext.Regression.Trainers.FastTree());
                var model = pipeline.Fit(dataView);
                SaveModelAsFile(mlContext, model);
                return model;
            }
    
            private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
            {
                using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
                    mlContext.Model.Save(model, fileStream);
            }
    
            private static void Evaluate(MLContext mlContext, ITransformer model)
            {
                IDataView dataView = _textLoader.Read(_testDataPath);
                var predictions = model.Transform(dataView);
                var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
                Console.WriteLine();
                Console.WriteLine($"*************************************************");
                Console.WriteLine($"*       Model quality metrics evaluation         ");
                Console.WriteLine($"*------------------------------------------------");
                Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
                Console.WriteLine($"*       RMS loss:      {metrics.Rms:#.##}");
            }
    
            private static void TestSinglePrediction(MLContext mlContext)
            {
                ITransformer loadedModel;
                using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
                {
                    loadedModel = mlContext.Model.Load(stream);
                }
    
                var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);
    
                var taxiTripSample = new TaxiTrip()
                {
                    VendorId = "VTS",
                    RateCode = "1",
                    PassengerCount = 1,
                    TripTime = 1140,
                    TripDistance = 3.75f,
                    PaymentType = "CRD",
                    FareAmount = 0 // To predict. Actual/Observed = 15.5
                };
    
                var prediction = predictionFunction.Predict(taxiTripSample);
    
                Console.WriteLine($"**********************************************************************");
                Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
                Console.WriteLine($"**********************************************************************");
            }
        }
    }
    

    Output of the program:

    *************************************************
    *       Model quality metrics evaluation
    *------------------------------------------------
    *       R2 Score:      0.92
    *       RMS loss:      2.81
    **********************************************************************
    Predicted fare: 15.7855, actual fare: 15.5
    **********************************************************************
    

    The final prediction is fairly close to the actual value.

  • Original post: https://www.cnblogs.com/kenwoo/p/10171481.html