zoukankan      html  css  js  c++  java
  • 学习OpenCV——SVM 手写数字检测

    转自http://blog.csdn.net/firefight/article/details/6452188

    是MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list

    其他方法:http://blog.csdn.net/onezeros/article/details/5672192

    使用OPENCV训练手写数字识别分类器 

    1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
    2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
    3,确定字符特征方式为最简单的8×8网格内的字符点数


    4,创建SVM,训练并读取,结果如下
     1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
      10000个训练样本,测试数据正确率95.45%
      60000个训练样本,测试数据正确率97.67%

    5,编写手写输入的GUI程序,并进行验证,效果还可以接受。

    以下为主要代码,以供参考

    (类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)

      1. #include "stdafx.h"   
      2.   
      3. #include <fstream>   
      4. #include "opencv2/opencv.hpp"   
      5. #include <vector>   
      6.   
      7. using namespace std;  
      8. using namespace cv;  
      9.   
      10. #define SHOW_PROCESS 0   
      11. #define ON_STUDY 0   
      12.   
      13. class NumTrainData  
      14. {  
      15. public:  
      16.     NumTrainData()  
      17.     {  
      18.         memset(data, 0, sizeof(data));  
      19.         result = -1;  
      20.     }  
      21. public:  
      22.     float data[64];  
      23.     int result;  
      24. };  
      25.   
      26. vector<NumTrainData> buffer;  
      27. int featureLen = 64;  
      28.   
      29. void swapBuffer(char* buf)  
      30. {  
      31.     char temp;  
      32.     temp = *(buf);  
      33.     *buf = *(buf+3);  
      34.     *(buf+3) = temp;  
      35.   
      36.     temp = *(buf+1);  
      37.     *(buf+1) = *(buf+2);  
      38.     *(buf+2) = temp;  
      39. }  
      40.   
      41. void GetROI(Mat& src, Mat& dst)  
      42. {  
      43.     int left, right, top, bottom;  
      44.     left = src.cols;  
      45.     right = 0;  
      46.     top = src.rows;  
      47.     bottom = 0;  
      48.   
      49.     //Get valid area   
      50.     for(int i=0; i<src.rows; i++)  
      51.     {  
      52.         for(int j=0; j<src.cols; j++)  
      53.         {  
      54.             if(src.at<uchar>(i, j) > 0)  
      55.             {  
      56.                 if(j<left) left = j;  
      57.                 if(j>right) right = j;  
      58.                 if(i<top) top = i;  
      59.                 if(i>bottom) bottom = i;  
      60.             }  
      61.         }  
      62.     }  
      63.   
      64.     //Point center;   
      65.     //center.x = (left + right) / 2;   
      66.     //center.y = (top + bottom) / 2;   
      67.   
      68.     int width = right - left;  
      69.     int height = bottom - top;  
      70.     int len = (width < height) ? height : width;  
      71.   
      72.     //Create a squre   
      73.     dst = Mat::zeros(len, len, CV_8UC1);  
      74.   
      75.     //Copy valid data to squre center   
      76.     Rect dstRect((len - width)/2, (len - height)/2, width, height);  
      77.     Rect srcRect(left, top, width, height);  
      78.     Mat dstROI = dst(dstRect);  
      79.     Mat srcROI = src(srcRect);  
      80.     srcROI.copyTo(dstROI);  
      81. }  
      82.   
      83. int ReadTrainData(int maxCount)  
      84. {  
      85.     //Open image and label file   
      86.     const char fileName[] = "../res/train-images.idx3-ubyte";  
      87.     const char labelFileName[] = "../res/train-labels.idx1-ubyte";  
      88.   
      89.     ifstream lab_ifs(labelFileName, ios_base::binary);  
      90.     ifstream ifs(fileName, ios_base::binary);  
      91.   
      92.     if( ifs.fail() == true )  
      93.         return -1;  
      94.   
      95.     if( lab_ifs.fail() == true )  
      96.         return -1;  
      97.   
      98.     //Read train data number and image rows / cols   
      99.     char magicNum[4], ccount[4], crows[4], ccols[4];  
      100.     ifs.read(magicNum, sizeof(magicNum));  
      101.     ifs.read(ccount, sizeof(ccount));  
      102.     ifs.read(crows, sizeof(crows));  
      103.     ifs.read(ccols, sizeof(ccols));  
      104.   
      105.     int count, rows, cols;  
      106.     swapBuffer(ccount);  
      107.     swapBuffer(crows);  
      108.     swapBuffer(ccols);  
      109.   
      110.     memcpy(&count, ccount, sizeof(count));  
      111.     memcpy(&rows, crows, sizeof(rows));  
      112.     memcpy(&cols, ccols, sizeof(cols));  
      113.   
      114.     //Just skip label header   
      115.     lab_ifs.read(magicNum, sizeof(magicNum));  
      116.     lab_ifs.read(ccount, sizeof(ccount));  
      117.   
      118.     //Create source and show image matrix   
      119.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
      120.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
      121.     Mat img, dst;  
      122.   
      123.     char label = 0;  
      124.     Scalar templateColor(255, 0, 255 );  
      125.   
      126.     NumTrainData rtd;  
      127.   
      128.     //int loop = 1000;   
      129.     int total = 0;  
      130.   
      131.     while(!ifs.eof())  
      132.     {  
      133.         if(total >= count)  
      134.             break;  
      135.           
      136.         total++;  
      137.         cout << total << endl;  
      138.           
      139.         //Read label   
      140.         lab_ifs.read(&label, 1);  
      141.         label = label + '0';  
      142.   
      143.         //Read source data   
      144.         ifs.read((char*)src.data, rows * cols);  
      145.         GetROI(src, dst);  
      146.   
      147. #if(SHOW_PROCESS)   
      148.         //Too small to watch   
      149.         img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);  
      150.         resize(dst, img, img.size());  
      151.   
      152.         stringstream ss;  
      153.         ss << "Number " << label;  
      154.         string text = ss.str();  
      155.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
      156.   
      157.         //imshow("img", img);   
      158. #endif   
      159.   
      160.         rtd.result = label;  
      161.         resize(dst, temp, temp.size());  
      162.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);   
      163.   
      164.         for(int i = 0; i<8; i++)  
      165.         {  
      166.             for(int j = 0; j<8; j++)  
      167.             {  
      168.                     rtd.data[ i*8 + j] = temp.at<uchar>(i, j);  
      169.             }  
      170.         }  
      171.   
      172.         buffer.push_back(rtd);  
      173.   
      174.         //if(waitKey(0)==27) //ESC to quit   
      175.         //  break;   
      176.   
      177.         maxCount--;  
      178.           
      179.         if(maxCount == 0)  
      180.             break;  
      181.     }  
      182.   
      183.     ifs.close();  
      184.     lab_ifs.close();  
      185.   
      186.     return 0;  
      187. }  
      188.   
      189. void newRtStudy(vector<NumTrainData>& trainData)  
      190. {  
      191.     int testCount = trainData.size();  
      192.   
      193.     Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);  
      194.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
      195.   
      196.     for (int i= 0; i< testCount; i++)   
      197.     {   
      198.   
      199.         NumTrainData td = trainData.at(i);  
      200.         memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float));  
      201.   
      202.         res.at<unsigned int>(i, 0) = td.result;  
      203.     }  
      204.   
      205.     /////////////START RT TRAINNING//////////////////   
      206.     CvRTrees forest;  
      207.     CvMat* var_importance = 0;  
      208.   
      209.     forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),  
      210.             CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));  
      211.     forest.save( "new_rtrees.xml" );  
      212. }  
      213.   
      214.   
      215. int newRtPredict()  
      216. {  
      217.     CvRTrees forest;  
      218.     forest.load( "new_rtrees.xml" );  
      219.   
      220.     const char fileName[] = "../res/t10k-images.idx3-ubyte";  
      221.     const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";  
      222.   
      223.     ifstream lab_ifs(labelFileName, ios_base::binary);  
      224.     ifstream ifs(fileName, ios_base::binary);  
      225.   
      226.     if( ifs.fail() == true )  
      227.         return -1;  
      228.   
      229.     if( lab_ifs.fail() == true )  
      230.         return -1;  
      231.   
      232.     char magicNum[4], ccount[4], crows[4], ccols[4];  
      233.     ifs.read(magicNum, sizeof(magicNum));  
      234.     ifs.read(ccount, sizeof(ccount));  
      235.     ifs.read(crows, sizeof(crows));  
      236.     ifs.read(ccols, sizeof(ccols));  
      237.   
      238.     int count, rows, cols;  
      239.     swapBuffer(ccount);  
      240.     swapBuffer(crows);  
      241.     swapBuffer(ccols);  
      242.   
      243.     memcpy(&count, ccount, sizeof(count));  
      244.     memcpy(&rows, crows, sizeof(rows));  
      245.     memcpy(&cols, ccols, sizeof(cols));  
      246.   
      247.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
      248.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
      249.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
      250.     Mat img, dst;  
      251.   
      252.     //Just skip label header   
      253.     lab_ifs.read(magicNum, sizeof(magicNum));  
      254.     lab_ifs.read(ccount, sizeof(ccount));  
      255.   
      256.     char label = 0;  
      257.     Scalar templateColor(255, 0, 0);  
      258.   
      259.     NumTrainData rtd;  
      260.   
      261.     int right = 0, error = 0, total = 0;  
      262.     int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;  
      263.     while(ifs.good())  
      264.     {  
      265.         //Read label   
      266.         lab_ifs.read(&label, 1);  
      267.         label = label + '0';  
      268.   
      269.         //Read data   
      270.         ifs.read((char*)src.data, rows * cols);  
      271.         GetROI(src, dst);  
      272.   
      273.         //Too small to watch   
      274.         img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);  
      275.         resize(dst, img, img.size());  
      276.   
      277.         rtd.result = label;  
      278.         resize(dst, temp, temp.size());  
      279.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);   
      280.         for(int i = 0; i<8; i++)  
      281.         {  
      282.             for(int j = 0; j<8; j++)  
      283.             {  
      284.                     m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);  
      285.             }  
      286.         }  
      287.   
      288.         if(total >= count)  
      289.             break;  
      290.   
      291.         char ret = (char)forest.predict(m);   
      292.   
      293.         if(ret == label)  
      294.         {  
      295.             right++;  
      296.             if(total <= 5000)  
      297.                 right_1++;  
      298.             else  
      299.                 right_2++;  
      300.         }  
      301.         else  
      302.         {  
      303.             error++;  
      304.             if(total <= 5000)  
      305.                 error_1++;  
      306.             else  
      307.                 error_2++;  
      308.         }  
      309.   
      310.         total++;  
      311.   
      312. #if(SHOW_PROCESS)   
      313.         stringstream ss;  
      314.         ss << "Number " << label << ", predict " << ret;  
      315.         string text = ss.str();  
      316.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
      317.   
      318.         imshow("img", img);  
      319.         if(waitKey(0)==27) //ESC to quit   
      320.             break;  
      321. #endif   
      322.   
      323.     }  
      324.   
      325.     ifs.close();  
      326.     lab_ifs.close();  
      327.   
      328.     stringstream ss;  
      329.     ss << "Total " << total << ", right " << right <<", error " << error;  
      330.     string text = ss.str();  
      331.     putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
      332.     imshow("img", img);  
      333.     waitKey(0);  
      334.   
      335.     return 0;  
      336. }  
      337.   
      338. void newSvmStudy(vector<NumTrainData>& trainData)  
      339. {  
      340.     int testCount = trainData.size();  
      341.   
      342.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
      343.     Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);  
      344.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
      345.   
      346.     for (int i= 0; i< testCount; i++)   
      347.     {   
      348.   
      349.         NumTrainData td = trainData.at(i);  
      350.         memcpy(m.data, td.data, featureLen*sizeof(float));  
      351.         normalize(m, m);  
      352.         memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));  
      353.   
      354.         res.at<unsigned int>(i, 0) = td.result;  
      355.     }  
      356.   
      357.     /////////////START SVM TRAINNING//////////////////   
      358.     CvSVM svm = CvSVM();   
      359.     CvSVMParams param;   
      360.     CvTermCriteria criteria;  
      361.   
      362.     criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);   
      363.     param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);   
      364.   
      365.     svm.train(data, res, Mat(), Mat(), param);  
      366.     svm.save( "SVM_DATA.xml" );  
      367. }  
      368.   
      369.   
      370. int newSvmPredict()  
      371. {  
      372.     CvSVM svm = CvSVM();   
      373.     svm.load( "SVM_DATA.xml" );  
      374.   
      375.     const char fileName[] = "../res/t10k-images.idx3-ubyte";  
      376.     const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";  
      377.   
      378.     ifstream lab_ifs(labelFileName, ios_base::binary);  
      379.     ifstream ifs(fileName, ios_base::binary);  
      380.   
      381.     if( ifs.fail() == true )  
      382.         return -1;  
      383.   
      384.     if( lab_ifs.fail() == true )  
      385.         return -1;  
      386.   
      387.     char magicNum[4], ccount[4], crows[4], ccols[4];  
      388.     ifs.read(magicNum, sizeof(magicNum));  
      389.     ifs.read(ccount, sizeof(ccount));  
      390.     ifs.read(crows, sizeof(crows));  
      391.     ifs.read(ccols, sizeof(ccols));  
      392.   
      393.     int count, rows, cols;  
      394.     swapBuffer(ccount);  
      395.     swapBuffer(crows);  
      396.     swapBuffer(ccols);  
      397.   
      398.     memcpy(&count, ccount, sizeof(count));  
      399.     memcpy(&rows, crows, sizeof(rows));  
      400.     memcpy(&cols, ccols, sizeof(cols));  
      401.   
      402.     Mat src = Mat::zeros(rows, cols, CV_8UC1);  
      403.     Mat temp = Mat::zeros(8, 8, CV_8UC1);  
      404.     Mat m = Mat::zeros(1, featureLen, CV_32FC1);  
      405.     Mat img, dst;  
      406.   
      407.     //Just skip label header   
      408.     lab_ifs.read(magicNum, sizeof(magicNum));  
      409.     lab_ifs.read(ccount, sizeof(ccount));  
      410.   
      411.     char label = 0;  
      412.     Scalar templateColor(255, 0, 0);  
      413.   
      414.     NumTrainData rtd;  
      415.   
      416.     int right = 0, error = 0, total = 0;  
      417.     int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;  
      418.     while(ifs.good())  
      419.     {  
      420.         //Read label   
      421.         lab_ifs.read(&label, 1);  
      422.         label = label + '0';  
      423.   
      424.         //Read data   
      425.         ifs.read((char*)src.data, rows * cols);  
      426.         GetROI(src, dst);  
      427.   
      428.         //Too small to watch   
      429.         img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);  
      430.         resize(dst, img, img.size());  
      431.   
      432.         rtd.result = label;  
      433.         resize(dst, temp, temp.size());  
      434.         //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);   
      435.         for(int i = 0; i<8; i++)  
      436.         {  
      437.             for(int j = 0; j<8; j++)  
      438.             {  
      439.                     m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);  
      440.             }  
      441.         }  
      442.   
      443.         if(total >= count)  
      444.             break;  
      445.   
      446.         normalize(m, m);  
      447.         char ret = (char)svm.predict(m);   
      448.   
      449.         if(ret == label)  
      450.         {  
      451.             right++;  
      452.             if(total <= 5000)  
      453.                 right_1++;  
      454.             else  
      455.                 right_2++;  
      456.         }  
      457.         else  
      458.         {  
      459.             error++;  
      460.             if(total <= 5000)  
      461.                 error_1++;  
      462.             else  
      463.                 error_2++;  
      464.         }  
      465.   
      466.         total++;  
      467.   
      468. #if(SHOW_PROCESS)   
      469.         stringstream ss;  
      470.         ss << "Number " << label << ", predict " << ret;  
      471.         string text = ss.str();  
      472.         putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
      473.   
      474.         imshow("img", img);  
      475.         if(waitKey(0)==27) //ESC to quit   
      476.             break;  
      477. #endif   
      478.   
      479.     }  
      480.   
      481.     ifs.close();  
      482.     lab_ifs.close();  
      483.   
      484.     stringstream ss;  
      485.     ss << "Total " << total << ", right " << right <<", error " << error;  
      486.     string text = ss.str();  
      487.     putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);  
      488.     imshow("img", img);  
      489.     waitKey(0);  
      490.   
      491.     return 0;  
      492. }  
      493.   
      494. int main( int argc, char *argv[] )  
      495. {  
      496. #if(ON_STUDY)   
      497.     int maxCount = 60000;  
      498.     ReadTrainData(maxCount);  
      499.   
      500.     //newRtStudy(buffer);   
      501.     newSvmStudy(buffer);  
      502. #else   
      503.     //newRtPredict();   
      504.     newSvmPredict();  
      505. #endif   
      506.     return 0;  
      507. }
      508. //from: http://blog.csdn.net/yangtrees/article/details/7458466
  • 相关阅读:
    MQTT入门1 -- mosquitto 安装
    利用wireshark抓取TCP的整个过程分析。
    ARM Linux驱动篇 学习温度传感器ds18b20的驱动编写过程
    移植ARM linux下远程连接工具dropbear
    飞凌2440开发板制作路由器
    基于视觉寻迹的寻路算法
    Linux I2C驱动架构
    Linux 设备树学习——基于i2c总线分析
    Linux SPI驱动学习——注册匹配
    从Linux内核LED驱动来理解字符设备驱动开发流程
  • 原文地址:https://www.cnblogs.com/GarfieldEr007/p/5401933.html
Copyright © 2011-2022 走看看