zoukankan      html  css  js  c++  java
  • 基于SGD、ASGD算法的SVM分类器(OpenCV案例源码train_svmsgd.cpp解读)

    此案例用于二分类问题(鼠标左键、右键点出两类点,会实时画出分界线),最终得到一条分界线(直线):f(x)=weights*x+shift

    源码不再贴出,只讲解最核心的doTrain()里的内容。参数含义翻译自ml.hpp文件。

    与SVM不同,SVMSGD不需要设置核函数。

    【参数】默认值见下述代码

    模型类型:SGD、ASGD(推荐)。随机梯度下降、平均随机梯度下降。
    边界类型:HARD_MARGIN、SOFT_MARGIN(推荐),前者用于线性可分,后者用于非线性可分
    边界规范化 lambda:推荐设为0.0001(对于SGD),0.00001(对于ASGD)。越小,异类被抛弃的越少。
    步长 gamma_0
    步长降低力度 c:推荐设置为1(对于SGD),0.75(对于ASGD)
    终止条件:TermCriteria::COUNT、TermCriteria::EPS、TermCriteria::COUNT + TermCriteria::EPS

    参数设置函数:

    setSvmsgdType()
    setMarginType()
    setMarginRegularization()
    setInitialStepSize()
    setStepDecreasingPower()

    【使用方式】

    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();//创建对象
    svmsgd->train(trainData);//训练
    svmsgd->save("MySvmsgd.xml");//保存模型
    svmsgd->load("MySvmsgd.xml");//加载模型
    svmsgd->predict(samples, responses);//预测,结果保存到responses标签中

    bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift)
    {
        //*创建SVMSGD对象
        cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); //创建SVMSGD对象
        //*设置参数,以下全是默认参数
        //svmsgd->setSvmsgdType(SVMSGD::ASGD); //模型类型
        //svmsgd->setMarginType(SVMSGD::SOFT_MARGIN); //边界类型
        //svmsgd->setMarginRegularization(0.00001); //边界规范化
        //svmsgd->setInitialStepSize(0.05);//步长
        //svmsgd->setStepDecreasingPower(0.75); //步长减弱力度
        //svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT,1000,1e-3));//终止条件,1000次迭代,0.001每次迭代的精度
        //*训练集
        cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
        //*训练
        svmsgd->train(trainData);
    
        if (svmsgd->isTrained()) //获取分界线的系数,f(x)=weights*x+shift
        {
            weights = svmsgd->getWeights();//x系数
            shift = svmsgd->getShift();//常数项
            //*保存模型
            svmsgd->save("svmsgd.xml"); //保存训练好的模型
            
            return true;
        }
        return false;
    }

    得到的xml中,weights有两个数,shift有一个数。

     f(x)=weights*x+shift,不可以理解为y=kx+b,应该理解为Ax+By+C=0。weights的两个数就是A、B,shift是C。

    Mat weights(1, 2, CV_32FC1); weights是一个1*2的向量,x也是1*2的向量(xi,xj)也就是(x,y)坐标。

    公式写全了就是:f(x)=weights1*xi+weights2*xj+shift,其实就是weights与x这两个向量的内积(对应相乘在求和)

    f(x)如果等于0,说明点在此直线上,大于0就在线的一边,小于0在线的另一边。

  • 相关阅读:
    spark 读取mongodb失败,报executor time out 和GC overhead limit exceeded 异常
    在zepplin 使用spark sql 查询mongodb的数据
    Unable to query from Mongodb from Zeppelin using spark
    spark 与zepplin 版本兼容
    kafka 新旧消费者的区别
    kafka 新生产者发送消息流程
    spark ui acl 不生效的问题分析
    python中if __name__ == '__main__': 的解析
    深入C++的new
    NSSplitView
  • 原文地址:https://www.cnblogs.com/xixixing/p/12430202.html
Copyright © 2011-2022 走看看