zoukankan      html  css  js  c++  java
  • 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

     一、概述

    通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下:

    之前介绍过,图片的特征是不能采用像素的灰度值的,这部分原理的台阶有点高,还好可以直接使用通过TensorFlow训练过的特征提取模型(美其名曰迁移学习)。

    模型文件为:tensorflow_inception_graph.pb

    二、样本介绍

     我随便在网上找了一些图片,分成6类:男孩、女孩、猫、狗、男人、女人。tags文件标记了每个文件所代表的类型标签(Label)。

    通过对这六类图片的学习,期望输入新的图片时,可以判断出是何种类型。

    三、代码

     全部代码:

    namespace TensorFlow_ImageClassification
    {    
    
        class Program
        {
            //Assets files download from:https://gitee.com/seabluescn/ML_Assets
            static readonly string AssetsFolder = @"D:StepByStepBlogsML_Assets";
            static readonly string TrainDataFolder = Path.Combine(AssetsFolder, "ImageClassification", "train");
            static readonly string TrainTagsPath = Path.Combine(AssetsFolder, "ImageClassification", "train_tags.tsv");
            static readonly string TestDataFolder = Path.Combine(AssetsFolder, "ImageClassification","test");
            static readonly string inceptionPb = Path.Combine(AssetsFolder, "TensorFlow", "tensorflow_inception_graph.pb");
            static readonly string imageClassifierZip = Path.Combine(Environment.CurrentDirectory, "MLModel", "imageClassifier.zip");
    
            //配置用常量
            private struct ImageNetSettings
            {
                public const int imageHeight = 224;
                public const int imageWidth = 224;
                public const float mean = 117;
                public const float scale = 1;
                public const bool channelsLast = true;
            }
    
            static void Main(string[] args)
            {
                TrainAndSaveModel();
                LoadAndPrediction();
    
                Console.WriteLine("Hit any key to finish the app");
                Console.ReadKey();
            }
    
            public static void TrainAndSaveModel()
            {
                MLContext mlContext = new MLContext(seed: 1);
    
                // STEP 1: 准备数据
                var fulldata = mlContext.Data.LoadFromTextFile<ImageNetData>(path: TrainTagsPath, separatorChar: '	', hasHeader: false);
    
                var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.1);
                var trainData = trainTestData.TrainSet;
                var testData = trainTestData.TestSet;
    
                // STEP 2:创建学习管道
                var pipeline = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelTokey", inputColumnName: "Label")
                    .Append(mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: TrainDataFolder, inputColumnName: nameof(ImageNetData.ImagePath)))
                    .Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: ImageNetSettings.imageWidth, imageHeight: ImageNetSettings.imageHeight, inputColumnName: "input"))
                    .Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: ImageNetSettings.channelsLast, offsetImage: ImageNetSettings.mean))
                    .Append(mlContext.Model.LoadTensorFlowModel(inceptionPb).
                         ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
                    .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelTokey", featureColumnName: "softmax2_pre_activation"))
                    .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabelValue", "PredictedLabel"))
                    .AppendCacheCheckpoint(mlContext);
    
                // STEP 3:通过训练数据调整模型    
                ITransformer model = pipeline.Fit(trainData);
    
                // STEP 4:评估模型
                Console.WriteLine("===== Evaluate model =======");
                var evaData = model.Transform(testData);
                var metrics = mlContext.MulticlassClassification.Evaluate(evaData, labelColumnName: "LabelTokey", predictedLabelColumnName: "PredictedLabel");
                PrintMultiClassClassificationMetrics(metrics);
    
                //STEP 5:保存模型
                Console.WriteLine("====== Save model to local file =========");
                mlContext.Model.Save(model, trainData.Schema, imageClassifierZip);
            }
    
            static void LoadAndPrediction()
            {
                MLContext mlContext = new MLContext(seed: 1);
    
                // Load the model
                ITransformer loadedModel = mlContext.Model.Load(imageClassifierZip, out var modelInputSchema);
    
                // Make prediction function (input = ImageNetData, output = ImageNetPrediction)
                var predictor = mlContext.Model.CreatePredictionEngine<ImageNetData, ImageNetPrediction>(loadedModel);
                
                DirectoryInfo testdir = new DirectoryInfo(TestDataFolder);
                foreach (var jpgfile in testdir.GetFiles("*.jpg"))
                {
                    ImageNetData image = new ImageNetData();
                    image.ImagePath = jpgfile.FullName;
                    var pred = predictor.Predict(image);
    
                    Console.WriteLine($"Filename:{jpgfile.Name}:	Predict Result:{pred.PredictedLabelValue}");
                }
            }       
        }
    
        public class ImageNetData
        {
            [LoadColumn(0)]
            public string ImagePath;
    
            [LoadColumn(1)]
            public string Label;
        }
    
        public class ImageNetPrediction
        {
            //public float[] Score;
            public string PredictedLabelValue;
        }   
    }
    View Code

      

    四、分析

     1、数据处理通道

    可以看出,其代码流程与结构与上两篇文章介绍的完全一致,这里就介绍一下核心的数据处理模型部分的代码:

    var pipeline = mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: "LabelTokey", inputColumnName: "Label")
      .Append(mlContext.Transforms.LoadImages(outputColumnName: "input", imageFolder: TrainDataFolder, inputColumnName: nameof(ImageNetData.ImagePath)))
      .Append(mlContext.Transforms.ResizeImages(outputColumnName: "input", imageWidth: ImageNetSettings.imageWidth, imageHeight: ImageNetSettings.imageHeight, inputColumnName: "input"))
      .Append(mlContext.Transforms.ExtractPixels(outputColumnName: "input", interleavePixelColors: ImageNetSettings.channelsLast, offsetImage: ImageNetSettings.mean))
      .Append(mlContext.Model.LoadTensorFlowModel(inceptionPb).
              ScoreTensorFlowModel(outputColumnNames: new[] { "softmax2_pre_activation" }, inputColumnNames: new[] { "input" }, addBatchDimensionInput: true))
      .Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "LabelTokey", featureColumnName: "softmax2_pre_activation"))
      .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabelValue", "PredictedLabel"))

    MapValueToKey与MapKeyToValue之前已经介绍过了;
    LoadImages是读取文件,输入为文件名、输出为Image;
    ResizeImages是改变图片尺寸,这一步是必须的,即使所有训练图片都是标准划一的图片也需要这个操作,后面需要根据这个尺寸确定容纳图片像素信息的数组大小;
    ExtractPixels是将图片转换为包含像素数据的矩阵;
    LoadTensorFlowModel是加载第三方模型,ScoreTensorFlowModel是调用模型处理数据,其输入为:“input”,输出为:“softmax2_pre_activation”,由于模型中输入、输出的名称是规定的,所以,这里的名称不可以随便修改。
    分类算法采用的是L-BFGS最大熵分类算法,其特征数据为TensorFlow网络输出的值,标签值为"LabelTokey"。

    2、验证过程
                MLContext mlContext = new MLContext(seed: 1);
                ITransformer loadedModel = mlContext.Model.Load(imageClassifierZip, out var modelInputSchema);           
                var predictor = mlContext.Model.CreatePredictionEngine<ImageNetData, ImageNetPrediction>(loadedModel);
                            
                ImageNetData image = new ImageNetData();
                image.ImagePath = jpgfile.FullName;
                var pred = predictor.Predict(image);
                Console.WriteLine($"Filename:{jpgfile.Name}:	Predict Result:{pred.PredictedLabelValue}");

     两个实体类代码:

        public class ImageNetData
        {
            [LoadColumn(0)]
            public string ImagePath;
    
            [LoadColumn(1)]
            public string Label;
        }
    
        public class ImageNetPrediction
        {       
            public string PredictedLabelValue;
        } 
    3、验证结果
    我在网络上又随便找了20张图片进行验证,发现验证成功率是非常高的,基本都是准确的,只有两个出错了。

    上图片被识别为girl(我认为是Woman),这个情有可原,本来girl和worman在外貌上也没有一个明确的分界点。

    上图被识别为woman,这个也情有可原,不解释。

    需要了解的是:不管你输入什么图片,最终的结果只能是以上六个类型之一,算法会寻找到和六个分类中特征最接近的一个分类作为结果。


    4、调试
    注意看实体类的话,我们只提供了三个基本属性,如果想看一下在学习过程中数据是如何处理的,可以给ImageNetPrediction类增加一些字段进行调试。
    首先我们需要看一下IDateView有哪些列(Column)
                var predictions = trainedModel.Transform(testData);          
    
                var OutputColumnNames = predictions.Schema.Where(col => !col.IsHidden).Select(col => col.Name);
                foreach (string column in OutputColumnNames)
                {
                    Console.WriteLine($"OutputColumnName:{ column }");
                }

     将我们要调试的列加入到实体对象中去,特别要注意数据类型。

        public class ImageNetPrediction
        {
            public float[] Score;
            public string PredictedLabelValue; 
            public string Label;
           
            public void PrintToConsole()
            {
                //打印字段信息
            }
        }  

     查看数据集详细信息:

               var predictions = trainedModel.Transform(testData); 
                var DataShowList = new List<ImageNetPrediction>(mlContext.Data.CreateEnumerable<ImageNetPrediction>(predictions, false, true));
               foreach (var dataline in DataShowList)
                {                
                        dataline.PrintToConsole();                               
                }
    
    

    五、资源获取 

    源码下载地址:https://github.com/seabluescn/Study_ML.NET

    工程名称:TensorFlow_ImageClassification

    资源获取:https://gitee.com/seabluescn/ML_Assets

    点击查看机器学习框架ML.NET学习笔记系列文章目录

  • 相关阅读:
    GTK+ 3.6.2 发布,小的 bug 修复版本
    RunJS 新增 Echo Ajax 测试功能
    Mozilla 发布 Popcorn Maker,在线创作视频
    Sina微博OAuth2框架解密
    Mina状态机State Machine
    Mozilla 发布 Shumway —— 纯JS的SWF解析器
    Code Browser 4.5 发布,代码浏览器
    ROSA 2012 "Enterprise Linux Server" 发布
    ltrace 0.7.0 发布,程序调试工具
    Artifactory 2.6.5 发布,Maven 扩展工具
  • 原文地址:https://www.cnblogs.com/seabluescn/p/10944579.html
Copyright © 2011-2022 走看看