  • Predicting New York Taxi Fares with ML.NET

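    With the previous post, "Playing with Machine Learning on .NET Core", as groundwork, this time we use New York taxi fare prediction as a new scenario and try out a regression model.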

    Scenario overview


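    Our goal is to predict New York taxi fares. At first glance the fare seems to depend only on the distance and duration of the trip, but New York taxi vendors also take other factors into account, such as additional passengers or paying by credit card rather than cash, and charge different amounts accordingly. New York City has officially published a sample of this data.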

    Choosing a strategy


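    To predict taxi fares, we build a regression model with machine learning. The model is fitted to the real data provided by the city, and training it reveals which features actually determine the fare. Once we have the model, we evaluate it; if the error is within an acceptable range, we use the model to predict fares for new data.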

    Solution


    • Create the project

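      Readers of the previous post will find this familiar. Using Visual Studio 2017 (recommended), create a .NET Core console application project named TaxiFarePrediction, then use the NuGet package manager to add a reference to Microsoft.ML. The code in this post uses the LearningPipeline API from the early Microsoft.ML preview releases (0.x); later releases removed that API, so newer package versions will not compile this code as written.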



    • Prepare the datasets

      Download the training dataset taxi-fare-train.csv and the test dataset taxi-fare-test.csv. Their contents look like this:
      vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
      VTS,1,1,1140,3.75,CRD,15.5
      VTS,1,1,480,2.72,CRD,10.0
      VTS,1,1,1680,7.8,CSH,26.5
      VTS,1,1,600,4.73,CSH,14.5
      VTS,1,1,600,2.18,CRD,9.5
      ...

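      A brief description of the fields:

      Field name          Meaning                 Role
      vendor_id           Vendor ID               feature
      rate_code           Rate code               feature
      passenger_count     Number of passengers    feature
      trip_time_in_secs   Trip duration (s)       feature
      trip_distance       Trip distance           feature
      payment_type        Payment type            feature
      fare_amount         Fare                    label (target value)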

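      Add a Data folder to the project, copy both datasets into it, and set the "Copy to Output Directory" property on each file so that they are copied to the build output.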
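      The ordinals used later in the [Column] attributes map each field to a position in these rows. As a plain-C# sketch of that positional mapping (no ML.NET involved), the hypothetical TaxiTripRow type below splits one data line on commas; it assumes there are no quoted fields and that the header line has already been skipped.

```csharp
using System;
using System.Globalization;

// Hypothetical plain-C# mirror of one row of taxi-fare-train.csv (not an ML.NET type).
public sealed class TaxiTripRow
{
    public string VendorId = "";
    public string RateCode = "";
    public float PassengerCount;
    public float TripTimeInSecs;
    public float TripDistance;
    public string PaymentType = "";
    public float FareAmount; // the label (target value)

    // Parses one data line by column position 0..6; assumes no quoted or escaped commas.
    public static TaxiTripRow Parse(string line)
    {
        string[] f = line.Split(',');
        if (f.Length != 7)
            throw new FormatException($"Expected 7 columns, got {f.Length}: {line}");

        // Invariant culture so "3.75" parses the same regardless of the machine's locale.
        var inv = CultureInfo.InvariantCulture;
        return new TaxiTripRow
        {
            VendorId = f[0],
            RateCode = f[1],
            PassengerCount = float.Parse(f[2], inv),
            TripTimeInSecs = float.Parse(f[3], inv),
            TripDistance = float.Parse(f[4], inv),
            PaymentType = f[5],
            FareAmount = float.Parse(f[6], inv),
        };
    }
}
```

      For example, TaxiTripRow.Parse("VTS,1,1,1140,3.75,CRD,15.5") yields a trip with TripDistance 3.75 and FareAmount 15.5. Skipping the header line corresponds to what useHeader: true does for TextLoader in the pipeline.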




    • Define the data types and paths

      First, declare the required namespaces.

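      using System;
      using Microsoft.ML.Models;
      using Microsoft.ML.Runtime;
      using Microsoft.ML.Runtime.Api;
      using Microsoft.ML.Trainers;
      using Microsoft.ML.Transforms;
      using System.Collections.Generic;
      using System.Linq;
      using Microsoft.ML;
      using System.Threading.Tasks; // Task, for the async Train method
      using System.IO;              // Directory, for creating the Models folder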

      Above the Main method, define the constants used in the program.

      const string DataPath = @".\Data\taxi-fare-train.csv";
      const string TestDataPath = @".\Data\taxi-fare-test.csv";
      const string ModelPath = @".\Models\Model.zip";
      const string ModelDirectory = @".\Models";
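      Relative paths like these resolve against the current working directory, and \ only acts as a directory separator on Windows. If the program should also run on Linux or macOS, one alternative (a suggestion, not what the original tutorial does) is to build the paths with Path.Combine relative to AppContext.BaseDirectory, which is where "Copy to Output Directory" places the Data folder. The AppPaths class name is hypothetical.

```csharp
using System;
using System.IO;

// Hypothetical helper that builds the same four paths portably.
static class AppPaths
{
    // Directory the compiled application runs from (the build output folder).
    static readonly string BaseDir = AppContext.BaseDirectory;

    // Static fields initialize in textual order, so BaseDir and ModelDirectory are set before use.
    public static readonly string DataPath = Path.Combine(BaseDir, "Data", "taxi-fare-train.csv");
    public static readonly string TestDataPath = Path.Combine(BaseDir, "Data", "taxi-fare-test.csv");
    public static readonly string ModelDirectory = Path.Combine(BaseDir, "Models");
    public static readonly string ModelPath = Path.Combine(ModelDirectory, "Model.zip");
}
```

      Because these values are computed at run time they are static readonly fields rather than const, but they can be passed to TextLoader<TaxiTrip> and WriteAsync in the same way as the constants.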

      Next, define the data types, including how each field corresponds to a column position in each row of the dataset.

      public class TaxiTrip
      {
          // ordinal: zero-based position of the column in each CSV row
          [Column(ordinal: "0")]
          public string vendor_id;
          [Column(ordinal: "1")]
          public string rate_code;
          [Column(ordinal: "2")]
          public float passenger_count;
          [Column(ordinal: "3")]
          public float trip_time_in_secs;
          [Column(ordinal: "4")]
          public float trip_distance;
          [Column(ordinal: "5")]
          public string payment_type;
          [Column(ordinal: "6")]
          public float fare_amount;
      }
      
      public class TaxiTripFarePrediction
      {
          // The regression trainer writes its prediction to the "Score" column.
          [ColumnName("Score")]
          public float fare_amount;
      }
      
    • Build the processing pipeline

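      Create a Train method that defines how the dataset is processed, then declare a model to receive the training result. Before returning, save the model to the specified location so that it can be loaded and reused later without retraining.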
      public static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> Train()
      {
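          var pipeline = new LearningPipeline();
      
          // Load the training data; the first line of the file is a header.
          pipeline.Add(new TextLoader<TaxiTrip>(DataPath, useHeader: true, separator: ","));
          // The trainer reads its target from a column named "Label", so copy fare_amount into it.
          pipeline.Add(new ColumnCopier(("fare_amount", "Label")));
          // Convert the string (categorical) columns into one-hot numeric vectors.
          pipeline.Add(new CategoricalOneHotVectorizer("vendor_id",
                                              "rate_code",
                                              "payment_type"));
          // Concatenate the selected columns into the single "Features" vector.
          // trip_time_in_secs is loaded but not used as a feature here.
          pipeline.Add(new ColumnConcatenator("Features",
                                              "vendor_id",
                                              "rate_code",
                                              "passenger_count",
                                              "trip_distance",
                                              "payment_type"));
          // FastTree: a gradient-boosted decision tree regressor.
          pipeline.Add(new FastTreeRegressor());
          // Run the pipeline over the training data to produce the model.
          PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
          // Save the model (creating the Models folder if needed) so it can be reused without retraining.
          if (!Directory.Exists(ModelDirectory))
          {
              Directory.CreateDirectory(ModelDirectory);
          }
          await model.WriteAsync(ModelPath);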
          return model;
      }
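      Conceptually, CategoricalOneHotVectorizer replaces each categorical string with an indicator vector, and ColumnConcatenator joins those vectors and the numeric columns into the one Features vector that FastTreeRegressor trains on. The sketch below reproduces that idea in plain C# for a single trip. It is an illustration only: the FeatureSketch class and the vocabularies passed to it are hypothetical, whereas ML.NET learns the categories from the training data and may order the slots differently.

```csharp
using System.Collections.Generic;
using System.Linq;

static class FeatureSketch
{
    // Returns a vector with 1 at the category's index and 0 elsewhere.
    // In this sketch, a value missing from the vocabulary yields all zeros.
    public static float[] OneHot(string value, IReadOnlyList<string> vocabulary)
    {
        var v = new float[vocabulary.Count];
        for (int i = 0; i < vocabulary.Count; i++)
        {
            if (vocabulary[i] == value) v[i] = 1f;
        }
        return v;
    }

    // Follows the column order given to ColumnConcatenator:
    // vendor_id, rate_code, passenger_count, trip_distance, payment_type.
    public static float[] Features(
        string vendorId, string rateCode, float passengerCount,
        float tripDistance, string paymentType,
        IReadOnlyList<string> vendors, IReadOnlyList<string> rateCodes,
        IReadOnlyList<string> paymentTypes)
    {
        return OneHot(vendorId, vendors)
            .Concat(OneHot(rateCode, rateCodes))
            .Concat(new[] { passengerCount, tripDistance })
            .Concat(OneHot(paymentType, paymentTypes))
            .ToArray();
    }
}
```

      With the hypothetical vocabularies {"CMT", "VTS"}, {"1", "2", "5"} and {"CRD", "CSH"}, the values of TestTrips.Trip1 produce the 9-element vector [0, 1, 1, 0, 0, 1, 10.33, 0, 1].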
    • Evaluate the model

      Create an Evaluate method that validates the trained model against the test dataset.
      public static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
      {
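          // Load the test set with the same schema as the training data.
          var testData = new TextLoader<TaxiTrip>(TestDataPath, useHeader: true, separator: ",");
          var evaluator = new RegressionEvaluator();
          // Score every test row and compare the predictions with the actual fare_amount labels.
          RegressionMetrics metrics = evaluator.Evaluate(model, testData);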
          // Rms should be around 2.795276
          Console.WriteLine("Rms=" + metrics.Rms);
          Console.WriteLine("RSquared = " + metrics.RSquared);
      }
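      Rms (root mean squared error) and RSquared (coefficient of determination) summarize how far the predictions fall from the actual fare_amount values in the test set. To make the printed numbers easier to interpret, the sketch below computes both with their standard textbook definitions; ML.NET's internal implementation may differ in details such as weighting, and the RegressionMetricsSketch class name is hypothetical.

```csharp
using System;

static class RegressionMetricsSketch
{
    // Root mean squared error: sqrt(mean((actual - predicted)^2)), in the label's unit.
    public static double Rms(double[] actual, double[] predicted)
    {
        Check(actual, predicted);
        double sum = 0;
        for (int i = 0; i < actual.Length; i++)
        {
            double e = actual[i] - predicted[i];
            sum += e * e;
        }
        return Math.Sqrt(sum / actual.Length);
    }

    // Coefficient of determination: 1 - SS_res / SS_tot, with SS_tot measured around the mean label.
    public static double RSquared(double[] actual, double[] predicted)
    {
        Check(actual, predicted);
        double mean = 0;
        foreach (double y in actual) mean += y;
        mean /= actual.Length;

        double ssRes = 0, ssTot = 0;
        for (int i = 0; i < actual.Length; i++)
        {
            ssRes += (actual[i] - predicted[i]) * (actual[i] - predicted[i]);
            ssTot += (actual[i] - mean) * (actual[i] - mean);
        }
        // Yields NaN or -Infinity if every actual value is identical (SS_tot == 0).
        return 1 - ssRes / ssTot;
    }

    static void Check(double[] actual, double[] predicted)
    {
        if (actual.Length == 0 || actual.Length != predicted.Length)
            throw new ArgumentException("Inputs must be non-empty and of equal length.");
    }
}
```

      Since Rms is in the label's unit, a value around 2.8 means the predictions are typically off by roughly $2.8 per trip on this test set, while an RSquared close to 1 means the model explains most of the variance in fares. These are the numbers to compare against the acceptable error range from the strategy.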
    • Predict new data

      Define a new trip to predict, assigning appropriate values to each feature.
      static class TestTrips
      {
          internal static readonly TaxiTrip Trip1 = new TaxiTrip
          {
              vendor_id = "VTS",
              rate_code = "1",
              passenger_count = 1,
              trip_distance = 10.33f,
              payment_type = "CSH",
              fare_amount = 0 // predict it. actual = 29.5
          };
      }

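      Making a prediction is simple: prediction holds the result, and we print the predicted fare alongside the actual fare (29.5, hard-coded from the known value).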

      var prediction = model.Predict(TestTrips.Trip1);
      
      Console.WriteLine("Predicted fare: {0}, actual fare: 29.5", prediction.fare_amount);
    • Results



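    That completes all the steps. For a detailed explanation of the code, see "Tutorial: Use ML.NET to Predict New York Taxi Fares (Regression)". Note, however, that some of the code in that tutorial is incorrect; the code in this post has been corrected. Because it uses a C# 7.1 language feature (an async Main method), the project's C# language version must be set to 7.1 or later. The complete code is as follows: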

    using System;
    using Microsoft.ML.Models;
    using Microsoft.ML.Runtime;
    using Microsoft.ML.Runtime.Api;
    using Microsoft.ML.Trainers;
    using Microsoft.ML.Transforms;
    using System.Collections.Generic;
    using System.Linq;
    using Microsoft.ML;
    using System.Threading.Tasks;
    using System.IO;
    
    namespace TaxiFarePrediction
    {
        class Program
        {
            const string DataPath = @".\Data\taxi-fare-train.csv";
            const string TestDataPath = @".\Data\taxi-fare-test.csv";
            const string ModelPath = @".\Models\Model.zip";
            const string ModelDirectory = @".\Models";
    
            public class TaxiTrip
            {
                [Column(ordinal: "0")]
                public string vendor_id;
                [Column(ordinal: "1")]
                public string rate_code;
                [Column(ordinal: "2")]
                public float passenger_count;
                [Column(ordinal: "3")]
                public float trip_time_in_secs;
                [Column(ordinal: "4")]
                public float trip_distance;
                [Column(ordinal: "5")]
                public string payment_type;
                [Column(ordinal: "6")]
                public float fare_amount;
            }
    
            public class TaxiTripFarePrediction
            {
                [ColumnName("Score")]
                public float fare_amount;
            }
    
            static class TestTrips
            {
                internal static readonly TaxiTrip Trip1 = new TaxiTrip
                {
                    vendor_id = "VTS",
                    rate_code = "1",
                    passenger_count = 1,
                    trip_distance = 10.33f,
                    payment_type = "CSH",
                    fare_amount = 0 // predict it. actual = 29.5
                };
            }
    
            public static async Task<PredictionModel<TaxiTrip, TaxiTripFarePrediction>> Train()
            {
                var pipeline = new LearningPipeline();
    
                pipeline.Add(new TextLoader<TaxiTrip>(DataPath, useHeader: true, separator: ","));
                pipeline.Add(new ColumnCopier(("fare_amount", "Label")));
                pipeline.Add(new CategoricalOneHotVectorizer("vendor_id",
                                                  "rate_code",
                                                  "payment_type"));
                pipeline.Add(new ColumnConcatenator("Features",
                                                    "vendor_id",
                                                    "rate_code",
                                                    "passenger_count",
                                                    "trip_distance",
                                                    "payment_type"));
                pipeline.Add(new FastTreeRegressor());
                PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = pipeline.Train<TaxiTrip, TaxiTripFarePrediction>();
                if (!Directory.Exists(ModelDirectory))
                {
                    Directory.CreateDirectory(ModelDirectory);
                }
                await model.WriteAsync(ModelPath);
                return model;
            }
    
            public static void Evaluate(PredictionModel<TaxiTrip, TaxiTripFarePrediction> model)
            {
                var testData = new TextLoader<TaxiTrip>(TestDataPath, useHeader: true, separator: ",");
                var evaluator = new RegressionEvaluator();
                RegressionMetrics metrics = evaluator.Evaluate(model, testData);
                // Rms should be around 2.795276
                Console.WriteLine("Rms=" + metrics.Rms);
                Console.WriteLine("RSquared = " + metrics.RSquared);
            }
    
            static async Task Main(string[] args)
            {
                PredictionModel<TaxiTrip, TaxiTripFarePrediction> model = await Train();
                Evaluate(model);
    
                var prediction = model.Predict(TestTrips.Trip1);
    
                Console.WriteLine("Predicted fare: {0}, actual fare: 29.5", prediction.fare_amount);
            }
        }
    }

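    Before we know it, our ML.NET journey has taken another step forward. Are you now more interested in using .NET Core and machine learning to solve real-world problems? Stay tuned.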

  • Original article: https://www.cnblogs.com/BeanHsiang/p/9017618.html