zoukankan      html  css  js  c++  java
  • 高斯牛顿迭代(GaussNewton)代码实现

    #include <cstdio>
    #include <vector>
    #include <iostream>
    #include <opencv2/core/core.hpp>
     
    using namespace std;
    using namespace cv;
     
    // Half-width of the central-difference step used by derivative().
    static const double DELTAX = 1.0e-5;
    // Upper bound on the number of Gauss-Newton iterations in gaussNewton().
    static const int MAXCOUNT = 100;
     
    // Model being fitted: y = exp(a*x*x + b*x + c).
    //
    // input  : matrix whose element (0,0) holds x (read as double).
    // params : column of doubles; rows 0, 1 and 2 hold a, b and c.
    // returns: the model value at x for the given parameters.
    double function(const Mat &input, const Mat params)
    {
        const double x = input.at<double>(0, 0);

        const double quad = params.at<double>(0, 0) * x * x; // a*x*x
        const double lin  = params.at<double>(1, 0) * x;     // b*x
        const double off  = params.at<double>(2, 0);         // c

        return exp(quad + lin + off);
    }
     
    
    double derivative(double(*function)(const Mat &input, const Mat params), const Mat &input, const Mat params, int n)
    {
        // 用增加分量的方式求导数 
        Mat params1 = params.clone();
        Mat params2 = params.clone();
     
        params1.at<double>(n,0) -= DELTAX;
        params2.at<double>(n,0) += DELTAX;
     
        double y1 = function(input, params1);
        double y2 = function(input, params2);
     
        double deri = (y2 - y1) / (2*DELTAX);
     
        return deri;
    }
     
    void gaussNewton(double(*function)(const Mat &input, const Mat ms), const Mat &inputs, const Mat &outputs, Mat params)
    {
        int num_estimates = inputs.rows;
        int num_params = params.rows;
     
        Mat r(num_estimates, 1, CV_64F);           // 残差
        Mat Jf(num_estimates, num_params, CV_64F); // 雅克比矩阵
        Mat input(1, 1, CV_64F);       
     
        double lsumR = 0;
     
        for(int i = 0; i < MAXCOUNT; i++) 
        {
            double sumR = 0;
     
            for(int j = 0; j < num_estimates; j++) 
            {
                input.at<double>(0,0) = inputs.at<double>(j,0);
    
                r.at<double>(j,0) = outputs.at<double>(j,0) - function(input, params);// 计算残差矩阵
     
                sumR += fabs(r.at<double>(j,0)); // 残差累加
     
                for(int k = 0; k < num_params; k++) 
                {
                    Jf.at<double>(j,k) = derivative(function, input, params, k); // 根据新参数重新计算雅克比矩阵
                }
            }
     
            sumR /= num_estimates; //均残差
    
            if(fabs(sumR - lsumR) < 1e-8) //均残差足够小达到收敛
            {
                break;
            }
     
            Mat delta = ((Jf.t()*Jf)).inv() * Jf.t()*r;// ((Jf.t()*Jf)) 近似黑塞矩阵
            params += delta;
            lsumR = sumR;
        }
    }
    
    int main()
    {
        // F = exp ( a*x*x + b*x + c )
    
        int num_params = 3;
        Mat params(num_params, 1, CV_64F);
    
         //abc参数的实际值
        params.at<double>(0,0) = 1.0; //a
        params.at<double>(1,0) = 2.0; //b
        params.at<double>(2,0) = 1.0; //c  
        
        cout<<"real("<<"a:"<< params.at<double>(0,0) <<" b:"<< params.at<double>(1,0) << " c:"<< params.at<double>(2,0) << ")"<< endl;
    
        int N = 100;                        
        
        double w_sigma = 1.0;               // 噪声Sigma值
        
        cv::RNG rng;                        // OpenCV随机数产生器
        
        Mat estimate_x(N, 1, CV_64F);
        Mat estimate_y(N, 1, CV_64F);
        
    
        for ( int i = 0; i < N; i++ )
        {
            double x = i/100.0;
            estimate_x.at<double>(i,0) = x; 
            Mat paramX(1, 1, CV_64F);
            paramX.at<double>(0,0) = x;
            estimate_y.at<double>(i,0) = function(paramX, params) + rng.gaussian ( w_sigma );
        }
        
        //abc参数的初值
        params.at<double>(0,0) = 0; //a
        params.at<double>(1,0) = 0; //b
        params.at<double>(2,0) = 0; //c
    
        cout<<"init("<<"a:"<< params.at<double>(0,0) <<" b:"<< params.at<double>(1,0) << " c:"<< params.at<double>(2,0) << ")"<< endl;
    
        gaussNewton(function, estimate_x, estimate_y, params);
     
        cout<<"result("<<"a:"<< params.at<double>(0,0) <<" b:"<< params.at<double>(1,0) << " c:"<< params.at<double>(2,0) << ")"<< endl;
     
        return 0;
    }
    # Project: GaussNewtonDemo
    #
    # All rights reserved.

    # CMake 4.x rejects minimum versions below 3.5, so 2.6 no longer configures.
    # The former `cmake_policy(SET CMP0004 OLD)` requested a deprecated OLD policy
    # behavior and is dropped.
    cmake_minimum_required( VERSION 3.5 )

    ### initialize project ###########################################################################################

    project(GaussNewtonDemo)

    # Default to Debug, but respect -DCMAKE_BUILD_TYPE=... given on the command line.
    if(NOT CMAKE_BUILD_TYPE)
        set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type" FORCE)
    endif()
    set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
    set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")

    # main.cpp only uses OpenCV core. Eigen was required but never used, which made
    # configuration fail on machines without it.
    find_package(OpenCV REQUIRED)

    # Default install prefix is /usr, unless the user passed -DCMAKE_INSTALL_PREFIX.
    if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
        set(CMAKE_INSTALL_PREFIX /usr CACHE PATH "Install prefix" FORCE)
    endif()


    set(BUILD_SHARED_LIBS on)

    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -O0")

    # OpenCVConfig.cmake exports OpenCV_INCLUDE_DIRS (plural); OpenCV_INCLUDE_DIR is not set.
    include_directories(
            ${OpenCV_INCLUDE_DIRS})

    ### global default options #######################################################################################

    set(SOURCES 
        main.cpp
    )

    add_executable(GaussNewtonDemo ${SOURCES})

    # Expose the install prefix to the sources as a C string literal.
    target_compile_definitions(GaussNewtonDemo PRIVATE "PREFIX=\"${CMAKE_INSTALL_PREFIX}\"")

    TARGET_LINK_LIBRARIES( GaussNewtonDemo
        ${OpenCV_LIBS} )
  • 相关阅读:
    AIBigKaldi(二)| Kaldi的I/O机制(源码解析)
    OfficialKaldi(十四)| 从命令行角度来看Kaldi的 I / O
    GNU Make函数、变量、指令
    C/C++编码规范(google)
    [English]precede, be preceded by
    视频压缩技术、I帧、P帧、B帧
    SMB
    printf占位符
    使用 Yocto Project 构建自定义嵌入式 Linux 发行版
    gcc fpic fPIC
  • 原文地址:https://www.cnblogs.com/yueyangtze/p/13680959.html
Copyright © 2011-2022 走看看