zoukankan      html  css  js  c++  java
  • TensorFlow.NET机器学习入门【5】采用神经网络实现手写数字识别(MNIST)

     从这篇文章开始,终于要干点正儿八经的工作了,前面都是准备工作。这次我们要解决机器学习的经典问题,MNIST手写数字识别。

    首先介绍一下数据集。请首先解压:TF_Net\Asset\mnist_png.tar.gz文件

     文件夹内包括两个文件夹:training和validation,其中training文件夹下包括60000个训练图片validation下包括10000个评估图片,图片为28*28像素,分别放在0~9十个文件夹中。

    程序总体流程和上一篇文章介绍的BMI分析程序基本一致,毕竟都是多元分类,有几点不一样。

    1、BMI程序的特征数据(输入)为一维数组,包含两个数字,MNIST的特征数据为28*28的二位数组;

    2、BMI程序的输出为3个,MNIST的输出为10个;

    网络模型构建如下:

            private readonly int img_rows = 28;
            private readonly int img_cols = 28;
            private readonly int num_classes = 10;  // total classes
            /// <summary>
            /// 构建网络模型
            /// </summary>     
            private Model BuildModel()
            {
                // 网络参数          
                int n_hidden_1 = 128;    // 1st layer number of neurons.     
                int n_hidden_2 = 128;    // 2nd layer number of neurons.                                
                float scale = 1.0f / 255;
    
                var model = keras.Sequential(new List<ILayer>
                {
                    keras.layers.InputLayer((img_rows,img_cols)),
                    keras.layers.Flatten(),
                    keras.layers.Rescaling(scale),
                    keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
                    keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
                    keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
                });
    
                return model;
            }

    这个网络里用到了两个新方法,需要解释一下:

    1、Flatten方法:这里表示拉平,把28*28的二维数组拉平为含784个数据的一维数组,因为二维数组无法进行运算;

    2、Rescaling 方法:就是对每个数据乘以一个系数,因为我们从图片获取的数据为每一个位点的灰度值,其取值范围为0~255,所以乘以一个系数将数据缩小到1以内,以免后面运算时溢出。

    其它基本和上一篇文章介绍的差不多,全部代码如下:

     /// <summary>
        /// 神经网络实现手写数字识别
        /// </summary>
        public class NN_MultipleClassification_MNIST
        {
            private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\train_data.bin";
            private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\train_label.bin";
    
            private readonly int img_rows = 28;
            private readonly int img_cols = 28;
            private readonly int num_classes = 10;  // total classes
    
            public void Run()
            {
                var model = BuildModel();
                model.summary();
    
                model.compile(optimizer: keras.optimizers.Adam(0.001f),
                    loss: keras.losses.SparseCategoricalCrossentropy(),
                    metrics: new[] { "accuracy" });
    
                (NDArray train_x, NDArray train_y) = LoadTrainingData();
                model.fit(train_x, train_y, batch_size: 1024, epochs: 10);
    
                test(model);
            }
    
            /// <summary>
            /// 构建网络模型
            /// </summary>     
            private Model BuildModel()
            {
                // 网络参数          
                int n_hidden_1 = 128;    // 1st layer number of neurons.     
                int n_hidden_2 = 128;    // 2nd layer number of neurons.                                
                float scale = 1.0f / 255;
    
                var model = keras.Sequential(new List<ILayer>
                {
                    keras.layers.InputLayer((img_rows,img_cols)),
                    keras.layers.Flatten(),
                    keras.layers.Rescaling(scale),
                    keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
                    keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
                    keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
                });
    
                return model;
            }
    
            /// <summary>
            /// 加载训练数据
            /// </summary>
            /// <param name="total_size"></param>    
            private (NDArray, NDArray) LoadTrainingData()
            {
                try
                {
                    Console.WriteLine("Load data");
                    IFormatter serializer = new BinaryFormatter();
                    FileStream loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read);
                    float[,,] arrx = serializer.Deserialize(loadFile) as float[,,];
    
                    loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read);
                    int[] arry = serializer.Deserialize(loadFile) as int[];
                    Console.WriteLine("Load data success");
                    return (np.array(arrx), np.array(arry));
                }
                catch (Exception ex)
                {
                    Console.WriteLine($"Load data Exception:{ex.Message}");
                    return LoadRawData();
                }
            }
    
            private (NDArray, NDArray) LoadRawData()
            {
                Console.WriteLine("LoadRawData");
    
                int total_size = 60000;
                float[,,] arrx = new float[total_size, img_rows, img_cols];
                int[] arry = new int[total_size];
    
                int count = 0;
                var TrainingImagePath = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\training";
                DirectoryInfo RootDir = new DirectoryInfo(TrainingImagePath);
                foreach (var Dir in RootDir.GetDirectories())
                {
                    foreach (var file in Dir.GetFiles("*.png"))
                    {
                        Bitmap bmp = (Bitmap)Image.FromFile(file.FullName);
                        if (bmp.Width != img_cols || bmp.Height != img_rows)
                        {
                            continue;
                        }
    
                        for (int row = 0; row < img_rows; row++)
                            for (int col = 0; col < img_cols; col++)
                            {
                                var pixel = bmp.GetPixel(col, row);
                                int val = (pixel.R + pixel.G + pixel.B) / 3;
    
                                arrx[count, row, col] = val;
                                arry[count] = int.Parse(Dir.Name);
                            }
    
                        count++;
                    }
    
                    Console.WriteLine($"Load image data count={count}");
                }
    
                Console.WriteLine("LoadRawData finished");
                //Save Data
                Console.WriteLine("Save data");
                IFormatter serializer = new BinaryFormatter();
    
                //开始序列化
                FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write);
                serializer.Serialize(saveFile, arrx);
                saveFile.Close();
    
                saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write);
                serializer.Serialize(saveFile, arry);
                saveFile.Close();
                Console.WriteLine("Save data finished");
    
                return (np.array(arrx), np.array(arry));
            }
    
            /// <summary>
            /// 消费模型
            /// </summary>      
            private void test(Model model)
            {
                var TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\test";
                DirectoryInfo TestDir = new DirectoryInfo(TestImagePath);
                foreach (var image in TestDir.GetFiles("*.png"))
                {
                    var x = LoadImage(image.FullName);
                    var pred_y = model.Apply(x);
                    var result = argmax(pred_y[0].numpy());
    
                    Console.WriteLine($"FileName:{image.Name}\tPred:{result}");
                }
            }
    
            private NDArray LoadImage(string filename)
            {
                float[,,] arrx = new float[1, img_rows, img_cols];
                Bitmap bmp = (Bitmap)Image.FromFile(filename);
    
                for (int row = 0; row < img_rows; row++)
                    for (int col = 0; col < img_cols; col++)
                    {
                        var pixel = bmp.GetPixel(col, row);
                        int val = (pixel.R + pixel.G + pixel.B) / 3;
                        arrx[0, row, col] = val;
                    }
    
                return np.array(arrx);
            }
    
            private int argmax(NDArray array)
            {
                var arr = array.reshape(-1);
    
                float max = 0;
                for (int i = 0; i < 10; i++)
                {
                    if (arr[i] > max)
                    {
                        max = arr[i];
                    }
                }
    
                for (int i = 0; i < 10; i++)
                {
                    if (arr[i] == max)
                    {
                        return i;
                    }
                }
    
                return 0;
            }
        }
    View Code

     另有两点说明:

    1、由于对图片的读取比较耗时,所以我采用了一个方法,就是把读取到的数据序列化到一个二进制文件中,下次直接从二进制文件反序列化即可,大大加快处理速度。如果找不到bin文件就从图片读取,bin文件我没有上传到git库里,所以下载项目后第一次运行需要一点时间。

    2、我没有采用validation图片进行评估,只是简单选了20个样本测试了一下。

    【相关资源】

     源码:Git: https://gitee.com/seabluescn/tf_not.git

    项目名称:NN_MultipleClassification_MNIST

    目录:查看TensorFlow.NET机器学习入门系列目录


    签名区:
    如果您觉得这篇博客对您有帮助或启发,请点击右侧【推荐】支持,谢谢!
  • 相关阅读:
    [转]我们应该做什么样的研究
    [转]面向服务架构(SOA)和企业服务总线(ESB)
    [转]程序员应知——团队精神
    vs2010 调试快捷键
    asp.net 获取ip的6种方法
    解决了防止用户重复登陆和session超时
    IE 10 也能随网站应变,图标决定一切!
    Sony VAIO Duo 11 游戏性能测试
    翻出Windows 8 当中的游戏管理器
    Office 2013预览版已到期,需要付费才可正常使用
  • 原文地址:https://www.cnblogs.com/seabluescn/p/15592834.html
Copyright © 2011-2022 走看看