zoukankan      html  css  js  c++  java
  • 简单HOG+SVM mnist手写数字分类

    使用工具 :VS2013 + OpenCV 3.1

    数据集:MNIST

    训练数据:60000张  测试数据:10000张  输出模型:HOG_SVM_DATA.xml  

    数据准备

    train-images-idx3-ubyte.gz:  training set images (9912422 bytes) 
    train-labels-idx1-ubyte.gz:  training set labels (28881 bytes) 
    t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes) 
    t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)

    首先我们利用matlab将数据转换成 .bmp 图片格式

    % Convert the MNIST training set (IDX files) into one .bmp file per image.
    % Each file is named <digit>_<NNNN>.bmp, where NNNN is the running count of
    % that digit, zero-padded to at least four characters.
    fid_image=fopen('train-images.idx3-ubyte','r');
    fid_label=fopen('train-labels.idx1-ubyte','r');
    if fid_image<0 || fid_label<0
        error('Cannot open train-images.idx3-ubyte / train-labels.idx1-ubyte');
    end
    % Image header: four big-endian int32 values (magic, image count, rows, cols)
    header=fread(fid_image,4,'int32',0,'ieee-be');
    numRows=header(3);
    numCols=header(4);
    % Label header: skip the first 8 bytes (magic number and label count)
    fread(fid_label,8);
    % Remaining bytes are the labels, one per image
    imageIndex=fread(fid_label);
    Num=length(imageIndex);
    % Running count of how many times each digit 0..9 has been seen
    cnt=zeros(1,10);
    for k=1:Num
        % Keep the pixels as uint8: imwrite assumes double data lies in [0,1],
        % so 0..255 doubles would saturate every non-zero pixel to white.
        % Pixels are stored row by row while fread fills column by column,
        % hence the [cols,rows] read followed by a transpose when writing.
        img=fread(fid_image,[numCols,numRows],'uint8=>uint8');
        if numel(img)<numRows*numCols
            break;      % image file ended before the label list did
        end
        val=imageIndex(k);      % digit shown in this image
        cnt(val+1)=cnt(val+1)+1;
        str=sprintf('%d_%04d.bmp',val,cnt(val+1));
        imwrite(img',str);
    end
    fclose(fid_image);
    fclose(fid_label);

    然后使用cmd指令写入图片路径: dir /b/s/p/w *.bmp > num.txt  并在每行路径后添加标签,每行格式为“图片路径 标签”(路径与标签之间用一个空格分隔,程序以最后一个空格作为分隔位置),如下图

    然后打乱样本顺序。

    训练

    int main0()
    {
        vector<string> img_path;//输入文件名变量     
        vector<int> img_catg;
        int nLine = 0;
        string line;
        size_t pos;
        ifstream svm_data("./train-images/random.txt");//训练样本图片的路径都写在这个txt文件中,使用bat批处理文件可以得到这个txt文件       
        unsigned long n;
        while (svm_data)//将训练样本文件依次读取进来      
        {
            if (getline(svm_data, line))
            {
                nLine++;
                pos = line.find_last_of(' ');
                img_path.push_back(line.substr(0, pos));//图像路径      
                img_catg.push_back(atoi(line.substr(pos + 1).c_str()));//atoi将字符串转换成整型,标志(0,1,2,...,9),注意这里至少要有两个类别,否则会出错      
            }
        }
    
        svm_data.close();//关闭文件      
        int nImgNum = nLine; //nImgNum是样本数量,只有文本行数的一半,另一半是标签         
        cv::Mat data_mat(nImgNum, 324, CV_32FC1);//第二个参数,即矩阵的列是由下面的descriptors的大小决定的,可以由descriptors.size()得到,且对于不同大小的输入训练图片,这个值是不同的    
        data_mat.setTo(cv::Scalar(0));
        //类型矩阵,存储每个样本的类型标志      
        cv::Mat res_mat(nImgNum, 1, CV_32S);
        res_mat.setTo(cv::Scalar(0));
        cv::Mat src;
        cv::Mat trainImg(cv::Size(28, 28), 8, 3);//需要分析的图片,这里默认设定图片是28*28大小,所以上面定义了324,如果要更改图片大小,可以先用debug查看一下descriptors是多少,然后设定好再运行      
    
        //处理HOG特征    
        for (string::size_type i = 0; i != img_path.size(); i++)
        {
            src = cv::imread(img_path[i].c_str(), 1);
            if (src.data == NULL)//if (src == NULL)
            {
                cout << " can not load the image: " << img_path[i].c_str() << endl;
                continue;
            }
    
            //cout << " 处理: " << img_path[i].c_str() << endl;
    
            cv::resize(src, trainImg, trainImg.size()); 
            cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(28, 28), cv::Size(14, 14), cv::Size(7, 7), cv::Size(7, 7), 9);
            vector<float>descriptors;//存放结果       
            hog->compute(trainImg, descriptors, cv::Size(1, 1), cv::Size(0, 0)); //Hog特征计算        
            //cout << "HOG dims: " << descriptors.size() << endl;
            n = 0;
            for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
            {
                //cvmSet(data_mat, i, n, *iter);
                data_mat.at<float>(i, n) = *iter;//存储HOG特征  
                n++;
            }
            //cvmSet(res_mat, i, 0, img_catg[i]);
            res_mat.at<int>(i, 0) = img_catg[i];
            //cout << " 处理完毕: " << img_path[i].c_str() << " " << img_catg[i] << endl;
        }
        cout << "computed features!" << endl;
    
        cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();//新建一个SVM        
        svm->setType(cv::ml::SVM::C_SVC);
        svm->setKernel(cv::ml::SVM::LINEAR);
        svm->setC(1);
        //-------------------不使用参数优化-------------------------//
        svm->setTermCriteria(cv::TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
        svm->train(data_mat, cv::ml::ROW_SAMPLE, res_mat);//训练数据   
        //-------------------参数优化-------------------------//
        //svm->setTermCriteria = cv::TermCriteria(cv::TermCriteria::MAX_ITER, (int)1e7, 1e-6);
        //cv::Ptr<cv::ml::TrainData> td = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, res_mat);
        //svm->trainAuto(td, 10);
    
        //保存训练好的分类器        
        svm->save("HOG_SVM_DATA.xml");
        cout << "saved model!" << endl;
        //检测样本      
        cv::Mat test;//IplImage *test;
        char result[512];
        vector<string> img_test_path;
        vector<int> img_test_catg;
        int coorect = 0;
        ifstream img_tst("./test-images/random.txt");  //加载需要预测的图片集合,这个文本里存放的是图片全路径,不要标签  
        while (img_tst)
        {
            if (getline(img_tst, line))
            {
                pos = line.find_last_of(' ');
                img_test_catg.push_back(atoi(line.substr(pos + 1).c_str()));//atoi将字符串转换成整型,标志(0,1,2,...,9),注意这里至少要有两个类别,否则会出错      
                img_test_path.push_back(line.substr(0, pos));//图像路径      
            }
        }
        img_tst.close();
    
        ofstream predict_txt("SVM_PREDICT.txt");//把预测结果存储在这个文本中     
        for (string::size_type j = 0; j != img_test_path.size(); j++)//依次遍历所有的待检测图片      
        {
            test = cv::imread(img_test_path[j].c_str(), 1);
            if (test.data == NULL)//test == NULL
            {
                cout << " can not load the image: " << img_test_path[j].c_str() << endl;
                continue;
            }
            cv::Mat trainTempImg(cv::Size(28, 28), 8, 3);
            trainTempImg.setTo(cv::Scalar(0));
            cv::resize(test, trainTempImg, trainTempImg.size());
            cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(28, 28), cv::Size(14, 14), cv::Size(7, 7), cv::Size(7, 7), 9);
            vector<float>descriptors;//结果数组         
            hog->compute(trainTempImg, descriptors, cv::Size(1, 1), cv::Size(0, 0));
            //cout << "HOG dims: " << descriptors.size() << endl;
            cv::Mat SVMtrainMat(1, descriptors.size(), CV_32FC1);
            int n = 0;
            for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
            {
                SVMtrainMat.at<float>(0, n) = *iter;
                n++;
            }
    
            int ret = svm->predict(SVMtrainMat);//检测结果  
            if (ret == img_test_catg[j])
                coorect++;
            sprintf(result, "%s  %d
    ", img_test_path[j].c_str(), ret);
            predict_txt << result;  //输出检测结果到文本  
        }
        predict_txt.close();
        cout << coorect*100 / img_test_path.size() << "%" << endl;
        return 0;
    }

    测试

    int main(int argc, char* argv[])
    {
        cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();   
        svm = cv::ml::SVM::load("HOG_SVM_DATA.xml");;//加载训练好的xml文件,这里训练的是10K个手写数字  
        //检测样本      
        cv::Mat test;
        char result[300]; //存放预测结果   
        test = cv::imread("6.bmp", 1); //待预测图片,用系统自带的画图工具随便手写  
        if (!test.data)
        {
            MessageBox(NULL, TEXT("待预测图像不存在!"), TEXT("提示"), MB_ICONWARNING);
            return -1;
        }
        cv::Mat trainTempImg(cv::Size(28, 28), 8, 3);
        trainTempImg.setTo(cv::Scalar(0));
        cv::resize(test, trainTempImg, trainTempImg.size());
        cv::HOGDescriptor *hog = new cv::HOGDescriptor(cv::Size(28, 28), cv::Size(14, 14), cv::Size(7, 7), cv::Size(7, 7), 9);
        vector<float>descriptors;//结果数组         
        hog->compute(trainTempImg, descriptors, cv::Size(1, 1), cv::Size(0, 0));
        //cout << "HOG dims: " << descriptors.size() << endl;
        cv::Mat SVMtrainMat(1, descriptors.size(), CV_32FC1);
        int n = 0;
        for (vector<float>::iterator iter = descriptors.begin(); iter != descriptors.end(); iter++)
        {
            SVMtrainMat.at<float>(0, n) = *iter;
            n++;
        }
        int ret = svm->predict(SVMtrainMat);//检测结果  
        sprintf(result, "%d
    ", ret);
        cv::namedWindow("dst", 0);
        cv::imshow("dst", test);
        MessageBox(NULL, result, TEXT("预测结果"), MB_OK);
        return 0;
    }
  • 相关阅读:
    Flink延时监控
    FLink全链路时延—测量方式
    Linux搭建SFTP服务器
    Red Hat:USING AMQ STREAMS WITH MIRRORMAKER 2.0
    idea 搭建运行kafka 源码
    Kafka Connect Concepts
    Java IPv6相关属性preferIPv4Stack、preferIPv6Addresses介绍
    如何确定Flink反压的根源?How to identify the source of backpressure?
    如何成为 Apache 项目的 Committer
    Apache Kafka KIP 介绍
  • 原文地址:https://www.cnblogs.com/xuanyuyt/p/6405944.html
Copyright © 2011-2022 走看看