zoukankan      html  css  js  c++  java
  • 在opencv3中实现机器学习之:利用逻辑斯谛回归(logistic regression)分类

    logistic regression,注意这个单词logistic ,并不是逻辑(logic)的意思,音译过来应该是逻辑斯谛回归,或者直接叫logistic回归,并不是什么逻辑回归。大部分人都叫成逻辑回归,无奈啊。。。虽然这个算法中有回归二字,但它做的事情却并不是回归,而是分类。这个算法只能解决简单的线性二分类,在众多的机器学习分类算法中并不出众,但它能被改进为多分类,并换了另外一个名字softmax, 这可是深度学习中响当当的分类算法。因此,logistic回归瞬间也变得高大上起来。

    本文用它来进行手写数字分类。在opencv3.0中提供了一个xml文件,里面存放了40个样本,分别是20个数字0的手写体和20个数字1的手写体。本来每个数字的手写体是一张28*28的小图片,但opencv把它reshape了一下,变成了1*784 的向量,然后放在xml文件中。这个文件的位置:

    opencvsourcessamplesdatadata01.xml

    代码:

    // face_detect.cpp : 定义控制台应用程序的入口点。
    //
    
    #include "stdafx.h"
    #include "opencv2opencv.hpp"
    #include <iostream>
    using namespace std;
    using namespace cv;
    using namespace cv::ml;
    
    //将向量转化成图片矩阵并显示
    void showImage(const Mat &data, int columns, const String &name)
    {
        Mat bigImage;
        for (int i = 0; i < data.rows; ++i)
        {
            bigImage.push_back(data.row(i).reshape(0, columns));
        }
        imshow(name, bigImage.t());
    }
    
    //计算分类精度
    float calculateAccuracyPercent(const Mat &original, const Mat &predicted)
    {
        return 100 * (float)countNonZero(original == predicted) / predicted.rows;
    }
    
    int main()
    {
        const String filename = "E:\opencv\opencv\sources\samples\data\data01.xml";
    
        Mat data, labels;  //训练数据及对应标注
    
        cout << "加载数据..." << endl;
        FileStorage f;
        if (f.open(filename, FileStorage::READ))
        {
            f["datamat"] >> data;
            f["labelsmat"] >> labels;
            f.release();
        }
        else
        {
            cerr << "文件无法打开: " << filename << endl;
            return 1;
        }
        data.convertTo(data, CV_32F);  //转换成float型
        labels.convertTo(labels, CV_32F);
        cout << "读取了 " << data.rows << "行数据" << endl;
    
        Mat data_train, data_test;
        Mat labels_train, labels_test;
        //将加载进来的数据均分成两部分,一部分用于训练,一部分用于测试
        for (int i = 0; i < data.rows; i++)
        {
            if (i % 2 == 0)
            {
                data_train.push_back(data.row(i));
                labels_train.push_back(labels.row(i));
            }
            else
            {
                data_test.push_back(data.row(i));
                labels_test.push_back(labels.row(i));
            }
        }
        cout << "训练数据: " << data_train.rows << "" << endl;
        cout<<"测试数据:"<< data_test.rows <<""<< endl;
    
        // 显示样本图片
        showImage(data_train, 28, "train data");
        showImage(data_test, 28, "test data");
    
        //创建分类器并设置参数
        Ptr<LogisticRegression> lr1 = LogisticRegression::create();  
        lr1->setLearningRate(0.001);
        lr1->setIterations(10);
        lr1->setRegularization(LogisticRegression::REG_L2);
        lr1->setTrainMethod(LogisticRegression::BATCH);
        lr1->setMiniBatchSize(1);
    
        //训练分类器
        lr1->train(data_train, ROW_SAMPLE, labels_train);
    
        Mat responses;
        //预测
        lr1->predict(data_test, responses);
    
        // 展示预测结果
        cout << "原始数据 vs 预测数据:" << endl;
        labels_test.convertTo(labels_test, CV_32S);  //转换为整型
        cout << labels_test.t() << endl;
        cout << responses.t() << endl;
        cout << "accuracy: " << calculateAccuracyPercent(labels_test, responses) << "%" << endl;
    
        waitKey(0);
        return 0;
    }

    从结果显示可以看出,待测数据(test data)是20个,算法分对了19个,精度为95%.

  • 相关阅读:
    Mysql基本类型(字符串类型)——mysql之二
    MySQL 中索引的长度的限制
    MySQL索引长度限制
    WebStorm 2019 3.3 安装及破解教程附汉化教程 Jetbrains2020全系列 2020.1.2 最新激活补丁
    用Swoole4 打造高并发的PHP协程Mysql连接池
    phpsocket.io
    php并发加锁
    PHP字符串全排列算法
    php beast windows编译教程
    使用PHP-Beast加密你的PHP源代码
  • 原文地址:https://www.cnblogs.com/denny402/p/5032490.html
Copyright © 2011-2022 走看看