zoukankan      html  css  js  c++  java
  • 【OpenCV】opencv3.0中的SVM训练 mnist 手写字体识别

    前言:

    SVM(支持向量机)一种训练分类器的学习方法

    mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个

    LibSVM 一个常用的SVM框架

    OpenCV3.0 中的ml包含了很多的ML框架接口,就试试了。

    详细的OpenCV文档:http://docs.opencv.org/3.0-beta/doc/tutorials/ml/introduction_to_svm/introduction_to_svm.html

    mnist数据下载:http://yann.lecun.com/exdb/mnist/

    LibSVM下载:http://www.csie.ntu.edu.tw/~cjlin/libsvm/

    ========================我是分割线=============================

    训练的过程大致如下:

    1. 读取mnist训练集数据

    2. 训练

    3. 读取mnist测试数据,对比预测结果,得到错误率

    具体实现:

    1. mnist给出的数据文件是二进制文件

        四个文件,解压后如下

       

      "train-images.idx3-ubyte" 二进制文件,存储了头文件信息以及60000张28*28图像pixel信息(用于训练)
      "train-labels.idx1-ubyte" 二进制文件,存储了头文件信息以及60000张图像label信息
      "t10k-images.idx3-ubyte"二进制文件,存储了头文件信息以及10000张28*28图像pixel信息(用于测试)
      "t10k-labels.idx1-ubyte"二进制文件,存储了头文件信息以及10000张图像label信息

      因为OpenCV中没有直接导入MINST数据的文件,所以需要自己写函数来读取

      首先要知道,MNIST数据的数据格式

      

       IMAGE FILE包含四个int型的头部数据(magic number,number_of_images, number_of_rows, number_of_columns)

           余下的每一个byte表示一个pixel的数据,范围是0-255(可以在读入的时候scale到0~1的区间

           LABEL FILE包含两个int型的头部数据(magic number, number of items)

           余下的每一个byte表示一个label数据,范围是0-9

       注意(第一个坑):MNIST是大端存储,然而大部分的Intel处理器都是小端存储,所以对于int、long、float这些多字节的数据类型,就要一个一个byte地翻转过来,才能正确显示。

      

     1 //翻转
     2 int reverseInt(int i) {
     3     unsigned char c1, c2, c3, c4;
     4 
     5     c1 = i & 255;
     6     c2 = (i >> 8) & 255;
     7     c3 = (i >> 16) & 255;
     8     c4 = (i >> 24) & 255;
     9 
    10     return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
    11 }
    View Code

      然后读取MNIST文件,但是它是二进制文件,打开方式

      所以不能用

      ifstream file(fileName);

      而要改成

      ifstream file(fileName, ios::binary);

      注意(第二个坑):如果用第一条指令来打开文件,不会报错,但是数据会出现错误,头部数据仍然正确,但是后面的pixel数据大部分都是0,我刚开始没注意,开始training的时候发现等了很久...真的是很久...(7+ hours)...估计是达到迭代终止的最大次数了,才停下来的

      嗯,stack overflow上也有类似的提问:

      

      注意(第三个坑):

      training时,IMAGE和LABEL的数据分别都放进一个MAT中存储,但是只能是CV32_F或者CV32_S的格式,不然会assertion报错

      OPENCV给出的文档中,例子是这样的:(但是predict的时候又会要求label的格式是unsigned int)所以...可以设置data的Mat格式为CV_32FC1,label的Mat格式为CV_32SC1

      

      顺便地,图像训练数据的转换存储格式(http://stackoverflow.com/questions/14694810/using-opencv-and-svm-with-images?rq=1)

      

      最后,为了验证读取数据的正确性,一个有效的办法就是输出第一个和最后一个数据(可以输出打印第一个/最后一个image以及label)

    2. 训练

      (此处我是直接对原图像训练,并没有提取任何的特征)

      也有人建议这里应该对图像做HOG特征提取,再配合label训练(我还没试过...不知道效果如何...)

      

      opencv3.0和2.4的SVM接口有不同,基本可以按照以下的格式来执行:

    ml::SVM::Params params;
    params.svmType = ml::SVM::C_SVC;
    params.kernelType = ml::SVM::POLY;
    params.gamma = 3;
    Ptr<ml::SVM> svm = ml::SVM::create(params);
    Mat trainData; // 每行为一个样本
    Mat labels;    
    svm->train( trainData , ml::ROW_SAMPLE , labels );
    // ...
    
    svm->save("....");//文件形式为xml,可以保存在txt或者xml文件中
    Ptr<SVM> svm=statModel::load<SVM>("....");
    
    Mat query; // 输入, 1个通道
    Mat res;   // 输出
    svm->predict(query, res);

     但是要注意,如果报错的话最好去看opencv3.0的文档,里面有函数原型和解释,我在实际操作的过程中,也做了一些改动

       1)设置参数

        SVM的参数有很多,但是与C_SVC和RBF有关的就只有gamma和C,所以设置这两个就好,终止条件设置和默认一样,由经验可得(其实是查阅了很多的资料,把gamma设置成0.01,这样训练收敛速度会快很多)

    Ptr<SVM> svm = SVM::create();
    svm->setType(SVM::C_SVC);
    svm->setKernel(SVM::RBF);
    svm->setGamma(0.01);
    svm->setC(10.0);
    svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000,FLT_EPSILON));

      svm_type –指定SVM的类型,下面是可能的取值:

      CvSVM::C_SVC C类支持向量分类机。 n类分组 (n geq 2),允许用异常值惩罚因子C进行不完全分类。
      CvSVM::NU_SVC u类支持向量分类机。n类似然不完全分类的分类器。参数为 u 取代C(其值在区间【0,1】中,nu越大,决策边界越平滑)。
      CvSVM::ONE_CLASS 单分类器,所有的训练数据提取自同一个类里,然后SVM建立了一个分界线以分割该类在特征空间中所占区域和其它类在特征空间中所占区域。
      CvSVM::EPS_SVR epsilon类支持向量回归机。训练集中的特征向量和拟合出来的超平面的距离需要小于p。异常值惩罚因子C被采用。
      CvSVM::NU_SVR u类支持向量回归机。 u 代替了 p。

      kernel_type –SVM的内核类型,下面是可能的取值:

      CvSVM::LINEAR 线性内核。没有任何向映射至高维空间,线性区分(或回归)在原始特征空间中被完成,这是最快的选择。K(x_i, x_j) = x_i^T x_j.
      CvSVM::POLY 多项式内核: K(x_i, x_j) = (gamma x_i^T x_j + coef0)^{degree}, gamma > 0.
      CvSVM::RBF 基于径向的函数,对于大多数情况都是一个较好的选择: K(x_i, x_j) = e^{-gamma ||x_i - x_j||^2}, gamma > 0.
      CvSVM::SIGMOID Sigmoid函数内核:K(x_i, x_j) = anh(gamma x_i^T x_j + coef0).

      degree – 内核函数(POLY)的参数degree。

      gamma – 内核函数(POLY/ RBF/ SIGMOID)的参数gamma。

      coef0 – 内核函数(POLY/ SIGMOID)的参数coef0。

      Cvalue – SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C。

      nu – SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 u。

      p – SVM类型(EPS_SVR)的参数 epsilon。

      class_weights – C_SVC中的可选权重,赋给指定的类,乘以C以后变成 class\_weights_i * C。所以这些权重影响不同类别的错误分类惩罚项。权重越大,某一类别的误分类数据的惩罚项就越大。

      term_crit – SVM的迭代训练过程的中止条件,解决部分受约束二次最优问题。您可以指定的公差和/或最大迭代次数。

       2)训练

    Mat trainData;
    Mat labels;
    trainData = read_mnist_image(trainImage);
    labels = read_mnist_label(trainLabel);
    
    svm->train(trainData, ROW_SAMPLE, labels);

       3)保存

    svm->save("mnist_dataset/mnist_svm.xml");

    3. 测试,比对结果

     (此处的FLT_EPSILON是一个极小的数,1.0 - FLT_EPSILON != 1.0)

    Mat testData;
    Mat tLabel;
    testData = read_mnist_image(testImage);
    tLabel = read_mnist_label(testLabel);
    
    float count = 0;
    for (int i = 0; i < testData.rows; i++) {
        Mat sample = testData.row(i);
        float res = svm1->predict(sample);
        res = std::abs(res - tLabel.at<unsigned int>(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;
        count += res;
    }
    cout << "正确的识别个数 count = " << count << endl;
    cout << "错误率为..." << (10000 - count + 0.0) / 10000 * 100.0 << "%....
    ";

    这里没有使用svm->predict(query, res);

    然后就查看了opencv的文档,当传入数据是Mat 而不是cvMat时,可以利用predict的返回值(float)来判断预测是否正确。

    运行结果:

    11000个训练数据/1000个测试数据

      

    2)2000个训练数据/2000个测试数据

      

    3)5000个训练数据/5000个测试数据

      

    410000个训练数据/10000个测试数据

      

    5)60000个训练数据/10000个测试数据

      

    最后,关于运行时间(在程序正确的前提下,训练时长和初始的参数设置有关),给出我最的运行结果(1000张图是11s左右,60000张是1300s ~ 2000s左右)

    代码:

     1 #ifndef MNIST_H  
     2 #define MNIST_H
     3 
     4 #include <iostream>
     5 #include <string>
     6 #include <fstream>
     7 #include <ctime>
     8 #include <opencv2/opencv.hpp>  
     9 
    10 using namespace cv;
    11 using namespace std;
    12 
    13 //小端存储转换
    14 int reverseInt(int i);
    15 
    16 //读取image数据集信息
    17 Mat read_mnist_image(const string fileName);
    18 
    19 //读取label数据集信息
    20 Mat read_mnist_label(const string fileName);
    21 
    22 #endif
    mnist.h
      1 #include "mnist.h"
      2 
      3 //计时器
      4 double cost_time;
      5 clock_t start_time;
      6 clock_t end_time;
      7 
      8 //测试item个数
      9 int testNum = 10000;
     10 
     11 int reverseInt(int i) {
     12     unsigned char c1, c2, c3, c4;
     13 
     14     c1 = i & 255;
     15     c2 = (i >> 8) & 255;
     16     c3 = (i >> 16) & 255;
     17     c4 = (i >> 24) & 255;
     18 
     19     return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
     20 }
     21 
     22 Mat read_mnist_image(const string fileName) {
     23     int magic_number = 0;
     24     int number_of_images = 0;
     25     int n_rows = 0;
     26     int n_cols = 0;
     27 
     28     Mat DataMat;
     29 
     30     ifstream file(fileName, ios::binary);
     31     if (file.is_open())
     32     {
     33         cout << "成功打开图像集 ... 
    ";
     34 
     35         file.read((char*)&magic_number, sizeof(magic_number));
     36         file.read((char*)&number_of_images, sizeof(number_of_images));
     37         file.read((char*)&n_rows, sizeof(n_rows));
     38         file.read((char*)&n_cols, sizeof(n_cols));
     39         //cout << magic_number << " " << number_of_images << " " << n_rows << " " << n_cols << endl;
     40 
     41         magic_number = reverseInt(magic_number);
     42         number_of_images = reverseInt(number_of_images);
     43         n_rows = reverseInt(n_rows);
     44         n_cols = reverseInt(n_cols);
     45         cout << "MAGIC NUMBER = " << magic_number
     46             << " ;NUMBER OF IMAGES = " << number_of_images
     47             << " ; NUMBER OF ROWS = " << n_rows
     48             << " ; NUMBER OF COLS = " << n_cols << endl;
     49 
     50         //-test-
     51         //number_of_images = testNum;
     52         //输出第一张和最后一张图,检测读取数据无误
     53         Mat s = Mat::zeros(n_rows, n_rows * n_cols, CV_32FC1);
     54         Mat e = Mat::zeros(n_rows, n_rows * n_cols, CV_32FC1);
     55 
     56         cout << "开始读取Image数据......
    ";
     57         start_time = clock();
     58         DataMat = Mat::zeros(number_of_images, n_rows * n_cols, CV_32FC1);
     59         for (int i = 0; i < number_of_images; i++) {
     60             for (int j = 0; j < n_rows * n_cols; j++) {
     61                 unsigned char temp = 0;
     62                 file.read((char*)&temp, sizeof(temp));
     63                 float pixel_value = float((temp + 0.0) / 255.0);
     64                 DataMat.at<float>(i, j) = pixel_value;
     65 
     66                 //打印第一张和最后一张图像数据
     67                 if (i == 0) {
     68                     s.at<float>(j / n_cols, j % n_cols) = pixel_value;
     69                 }
     70                 else if (i == number_of_images - 1) {
     71                     e.at<float>(j / n_cols, j % n_cols) = pixel_value;
     72                 }
     73             }
     74         }
     75         end_time = clock();
     76         cost_time = (end_time - start_time) / CLOCKS_PER_SEC;
     77         cout << "读取Image数据完毕......" << cost_time << "s
    ";
     78 
     79         imshow("first image", s);
     80         imshow("last image", e);
     81         waitKey(0);
     82     }
     83     file.close();
     84     return DataMat;
     85 }
     86 
     87 Mat read_mnist_label(const string fileName) {
     88     int magic_number;
     89     int number_of_items;
     90 
     91     Mat LabelMat;
     92 
     93     ifstream file(fileName, ios::binary);
     94     if (file.is_open())
     95     {
     96         cout << "成功打开Label集 ... 
    ";
     97 
     98         file.read((char*)&magic_number, sizeof(magic_number));
     99         file.read((char*)&number_of_items, sizeof(number_of_items));
    100         magic_number = reverseInt(magic_number);
    101         number_of_items = reverseInt(number_of_items);
    102 
    103         cout << "MAGIC NUMBER = " << magic_number << "  ; NUMBER OF ITEMS = " << number_of_items << endl;
    104 
    105         //-test-
    106         //number_of_items = testNum;
    107         //记录第一个label和最后一个label
    108         unsigned int s = 0, e = 0;
    109 
    110         cout << "开始读取Label数据......
    ";
    111         start_time = clock();
    112         LabelMat = Mat::zeros(number_of_items, 1, CV_32SC1);
    113         for (int i = 0; i < number_of_items; i++) {
    114             unsigned char temp = 0;
    115             file.read((char*)&temp, sizeof(temp));
    116             LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
    117 
    118             //打印第一个和最后一个label
    119             if (i == 0) s = (unsigned int)temp;
    120             else if (i == number_of_items - 1) e = (unsigned int)temp;
    121         }
    122         end_time = clock();
    123         cost_time = (end_time - start_time) / CLOCKS_PER_SEC;
    124         cout << "读取Label数据完毕......" << cost_time << "s
    ";
    125 
    126         cout << "first label = " << s << endl;
    127         cout << "last label = " << e << endl;
    128     }
    129     file.close();
    130     return LabelMat;
    131 }
    mnist.cpp
      1 /*
      2 svm_type –
      3 指定SVM的类型,下面是可能的取值:
      4 CvSVM::C_SVC C类支持向量分类机。 n类分组  (n geq 2),允许用异常值惩罚因子C进行不完全分类。
      5 CvSVM::NU_SVC 
    u类支持向量分类机。n类似然不完全分类的分类器。参数为 
    u 取代C(其值在区间【0,1】中,nu越大,决策边界越平滑)。
      6 CvSVM::ONE_CLASS 单分类器,所有的训练数据提取自同一个类里,然后SVM建立了一个分界线以分割该类在特征空间中所占区域和其它类在特征空间中所占区域。
      7 CvSVM::EPS_SVR epsilon类支持向量回归机。训练集中的特征向量和拟合出来的超平面的距离需要小于p。异常值惩罚因子C被采用。
      8 CvSVM::NU_SVR 
    u类支持向量回归机。 
    u 代替了 p。
      9 
     10 可从 [LibSVM] 获取更多细节。
     11 
     12 kernel_type –
     13 SVM的内核类型,下面是可能的取值:
     14 CvSVM::LINEAR 线性内核。没有任何向映射至高维空间,线性区分(或回归)在原始特征空间中被完成,这是最快的选择。K(x_i, x_j) = x_i^T x_j.
     15 CvSVM::POLY 多项式内核: K(x_i, x_j) = (gamma x_i^T x_j + coef0)^{degree}, gamma > 0.
     16 CvSVM::RBF 基于径向的函数,对于大多数情况都是一个较好的选择: K(x_i, x_j) = e^{-gamma ||x_i - x_j||^2}, gamma > 0.
     17 CvSVM::SIGMOID Sigmoid函数内核:K(x_i, x_j) = 	anh(gamma x_i^T x_j + coef0).
     18 
     19 degree – 内核函数(POLY)的参数degree。
     20 
     21 gamma – 内核函数(POLY/ RBF/ SIGMOID)的参数gamma。
     22 
     23 coef0 – 内核函数(POLY/ SIGMOID)的参数coef0。
     24 
     25 Cvalue – SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C。
     26 
     27 nu – SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 
    u。
     28 
     29 p – SVM类型(EPS_SVR)的参数 epsilon。
     30 
     31 class_weights – C_SVC中的可选权重,赋给指定的类,乘以C以后变成 class\_weights_i * C。所以这些权重影响不同类别的错误分类惩罚项。权重越大,某一类别的误分类数据的惩罚项就越大。
     32 
     33 term_crit – SVM的迭代训练过程的中止条件,解决部分受约束二次最优问题。您可以指定的公差和/或最大迭代次数。
     34 
     35 */
     36 
     37 
     38 #include "mnist.h"
     39 
     40 #include <opencv2/core.hpp>
     41 #include <opencv2/imgproc.hpp>
     42 #include "opencv2/imgcodecs.hpp"
     43 #include <opencv2/highgui.hpp>
     44 #include <opencv2/ml.hpp>
     45 
     46 #include <string>
     47 #include <iostream>
     48 
     49 using namespace std;
     50 using namespace cv;
     51 using namespace cv::ml;
     52 
     53 string trainImage = "mnist_dataset/train-images.idx3-ubyte";
     54 string trainLabel = "mnist_dataset/train-labels.idx1-ubyte";
     55 string testImage = "mnist_dataset/t10k-images.idx3-ubyte";
     56 string testLabel = "mnist_dataset/t10k-labels.idx1-ubyte";
     57 //string testImage = "mnist_dataset/train-images.idx3-ubyte";
     58 //string testLabel = "mnist_dataset/train-labels.idx1-ubyte";
     59 
     60 //计时器
     61 double cost_time_;
     62 clock_t start_time_;
     63 clock_t end_time_;
     64 
     65 int main()
     66 {
     67     
     68     //--------------------- 1. Set up training data ---------------------------------------
     69     Mat trainData;
     70     Mat labels;
     71     trainData = read_mnist_image(trainImage);
     72     labels = read_mnist_label(trainLabel);
     73 
     74     cout << trainData.rows << " " << trainData.cols << endl;
     75     cout << labels.rows << " " << labels.cols << endl;
     76 
     77     //------------------------ 2. Set up the support vector machines parameters --------------------
     78     Ptr<SVM> svm = SVM::create();
     79     svm->setType(SVM::C_SVC);
     80     svm->setKernel(SVM::RBF);
     81     //svm->setDegree(10.0);
     82     svm->setGamma(0.01);
     83     //svm->setCoef0(1.0);
     84     svm->setC(10.0);
     85     //svm->setNu(0.5);
     86     //svm->setP(0.1);
     87     svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
     88 
     89     //------------------------ 3. Train the svm ----------------------------------------------------
     90     cout << "Starting training process" << endl;
     91     start_time_ = clock();
     92     svm->train(trainData, ROW_SAMPLE, labels);
     93     end_time_ = clock();
     94     cost_time_ = (end_time_ - start_time_) / CLOCKS_PER_SEC;
     95     cout << "Finished training process...cost " << cost_time_ << " seconds..." << endl;
     96     
     97     //------------------------ 4. save the svm ----------------------------------------------------
     98     svm->save("mnist_dataset/mnist_svm.xml");
     99     cout << "save as /mnist_dataset/mnist_svm.xml" << endl;
    100 
    101     
    102     //------------------------ 5. load the svm ----------------------------------------------------
    103     cout << "开始导入SVM文件...
    ";
    104     Ptr<SVM> svm1 = StatModel::load<SVM>("mnist_dataset/mnist_svm.xml");
    105     cout << "成功导入SVM文件...
    ";
    106 
    107 
    108     //------------------------ 6. read the test dataset -------------------------------------------
    109     cout << "开始导入测试数据...
    ";
    110     Mat testData;
    111     Mat tLabel;
    112     testData = read_mnist_image(testImage);
    113     tLabel = read_mnist_label(testLabel);
    114     cout << "成功导入测试数据!!!
    ";
    115 
    116     
    117     float count = 0;
    118     for (int i = 0; i < testData.rows; i++) {
    119         Mat sample = testData.row(i);
    120         float res = svm1->predict(sample);
    121         res = std::abs(res - tLabel.at<unsigned int>(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;
    122         count += res;
    123     }
    124     cout << "正确的识别个数 count = " << count << endl;
    125     cout << "错误率为..." << (10000 - count + 0.0) / 10000 * 100.0 << "%....
    ";
    126     
    127     system("pause");
    128     return 0;
    129 }
    main.cpp

    一些网站(资料):(其实都很容易搜索到的=_=, 但是搬了人家的东西,就还是贴一下...

    http://blog.csdn.net/augusdi/article/details/9005352

    http://blog.csdn.net/arthur503/article/details/19974057

    http://blog.csdn.net/laihonghuan/article/details/49387237

    http://docs.opencv.org/3.0-beta/modules/ml/doc/support_vector_machines.html#prediction-with-svm

    http://stackoverflow.com/questions/14694810/using-opencv-and-svm-with-images?rq=1

    http://docs.opencv.org/2.4/modules/ml/doc/support_vector_machines.html#cvsvm-train

    http://blog.csdn.net/u010869312/article/details/44927721

    http://blog.csdn.net/heroacool/article/details/50579955

    http://docs.opencv.org/3.0-beta/doc/tutorials/ml/introduction_to_svm/introduction_to_svm.html

    http://guyvercz.blog.163.com/blog/static/252545292011112974915402/

    http://stackoverflow.com/questions/12993941/how-can-i-read-the-mnist-dataset-with-c?lq=1

  • 相关阅读:
    Caused by: java.lang.ClassNotFoundException: org.hibernate.annotations.common.reflection.MetadataPro
    Caused by: java.lang.ClassNotFoundException: org.dom4j.DocumentException
    让你的C程序更有效率的10种方法
    c++效率浅析
    Caused by: java.lang.ClassNotFoundException: org.hibernate.engine.FilterDefinition
    Caused by: java.lang.ClassNotFoundException: javax.transaction.TransactionManager
    Caused by: java.lang.NoClassDefFoundError: org/hibernate/cfg/Configuration
    Caused by: java.lang.ClassNotFoundException: org.aspectj.weaver.reflect.ReflectionWorld$ReflectionWo
    Caused by: java.lang.ClassNotFoundException: org.jbpm.pvm.internal.processengine.SpringHelper
    Caused by: java.lang.ClassNotFoundException: org.aopalliance.intercept.MethodInterceptor
  • 原文地址:https://www.cnblogs.com/cheermyang/p/5624333.html
Copyright © 2011-2022 走看看