zoukankan      html  css  js  c++  java
  • cubic spline interpolation 概念解释和实现

    博客参考:https://blog.csdn.net/flyingleo1981/article/details/53008931

    样条插值是一种工业设计中常用的、得到平滑曲线的一种插值方法,三次样条又是其中用的较为广泛的一种。本篇介绍力求用容易理解的方式,介绍一下三次样条插值的原理,并附C语言的实现代码。

    1. 三次样条曲线原理

    假设有以下节点

    1.1 定义

    条曲线image 是一个分段定义的公式。给定n+1个数据点,共有n个区间,三次样条方程满足以下条件:

    1. 在每个分段区间image (i = 0, 1, …, n-1,x递增), image 都是一个三次多项式
    2. 满足image (i = 0, 1, …, n )
    3. image ,导数image ,二阶导数image 在[a, b]区间都是连续的,即image曲线是光滑的。

    所以n个三次多项式分段可以写作:

    image ,i = 0, 1, …, n-1

    其中ai, bi, ci, di代表4n个未知系数。

    1.2 求解

    已知条件

    1. n+1个数据点[xi, yi], i = 0, 1, …, n
    2. 每一分段都是三次多项式函数曲线
    3. 节点达到二阶连续
    4. 左右两端点处特性(自然边界,固定边界,非节点边界)

    根据定点,求出每段样条曲线方程中的系数,即可得到每段曲线的具体表达式。

    • 插值和连续性:

    image, 其中 i = 0, 1, …, n-1

    • 微分连续性:

    image , 其中 i = 0, 1, …, n-2

    • 样条曲线的微分式:

     将步长 带入样条曲线的条件:

    1.  由image (i = 0, 1, …, n-1) 推出: 
    2. image (i = 0, 1, …, n-1) 推出:  
    3. 由 image (i = 0, 1, …, n-2) 推出: 

      由此可得:

      

       4. 由 image (i = 0, 1, …, n-2) 推出: 

    • image ,则 image 可写为:image ,推出: image
    • 将ci, di带入 image 可得:
    •  将bi, ci, di带入image (i = 0, 1, …, n-2)可得:

    端点条件

     由i的取值范围可知,共有n-1个公式, 但却有n+1个未知量m 。要想求解该方程组,还需另外两个式子。所以需要对两端点x0和xn的微分加些限制。 选择不是唯一的,3种比较常用的限制如下。

    1. 自由边界(Natural)

    首尾两端没有受到任何让它们弯曲的力,即image 。具体表示为image 和 image. 则要求解的方程组可写为:

      2. 固定边界(Clamped)

    首尾两端点的微分值是被指定的,这里分别定为A和B。则可以推出

     

     将上述两个公式带入方程组,新的方程组左侧为

     

       3. 非节点边界(Not-A-Knot)

    指定样条曲线的三次微分匹配,即

     和 

    根据image 和image ,则上述条件变为:  和 

     新的方程组系数矩阵可写为:

     右下图可以看出不同的端点边界对样条曲线的影响:

     

    1.3 算法总结

    假定有n+1个数据节点: 

    1.  计算步长image (i = 0, 1, …, n-1)
    2. 将数据节点和指定的首位端点条件带入矩阵方程
    3. 解矩阵方程,求得二次微分值image。该矩阵为三对角矩阵,具体求法参见我的上篇文章:三对角矩阵的求解
    4. 计算样条曲线的系数:

     其中i = 0, 1, …, n-1

      5. 在每个子区间image 中,创建方程 

    2. C++ 语言实现

    C++语言写了一个三次样条插值(自然边界)函数,代码为 Udacity Path Planning 课程中使用的 simple cubic spline interpolation library without external 文件。

    /*
     * spline.h
     *
     * simple cubic spline interpolation library without external
     * dependencies
     *
     * ---------------------------------------------------------------------
     * Copyright (C) 2011, 2014 Tino Kluge (ttk448 at gmail.com)
     *
     *  This program is free software; you can redistribute it and/or
     *  modify it under the terms of the GNU General Public License
     *  as published by the Free Software Foundation; either version 2
     *  of the License, or (at your option) any later version.
     *
     *  This program is distributed in the hope that it will be useful,
     *  but WITHOUT ANY WARRANTY; without even the implied warranty of
     *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     *  GNU General Public License for more details.
     *
     *  You should have received a copy of the GNU General Public License
     *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
     * ---------------------------------------------------------------------
     *
     */
    
    
    #ifndef TK_SPLINE_H
    #define TK_SPLINE_H
    
    #include <cstdio>
    #include <cassert>
    #include <vector>
    #include <algorithm>
    
    // unnamed namespace only because the implementation is in this
    // header file and we don't want to export symbols to the obj files
    namespace
    {
    
    namespace tk
    {
    
    // band matrix solver
    class band_matrix
    {
    private:
        std::vector< std::vector<double> > m_upper;  // upper band
        std::vector< std::vector<double> > m_lower;  // lower band
    public:
        band_matrix() {};                             // constructor
        band_matrix(int dim, int n_u, int n_l);       // constructor
        ~band_matrix() {};                            // destructor
        void resize(int dim, int n_u, int n_l);      // init with dim,n_u,n_l
        int dim() const;                             // matrix dimension
        int num_upper() const
        {
            return m_upper.size()-1;
        }
        int num_lower() const
        {
            return m_lower.size()-1;
        }
        // access operator
        double & operator () (int i, int j);            // write
        double   operator () (int i, int j) const;      // read
        // we can store an additional diogonal (in m_lower)
        double& saved_diag(int i);
        double  saved_diag(int i) const;
        void lu_decompose();
        std::vector<double> r_solve(const std::vector<double>& b) const;
        std::vector<double> l_solve(const std::vector<double>& b) const;
        std::vector<double> lu_solve(const std::vector<double>& b,
                                     bool is_lu_decomposed=false);
    
    };
    
    // spline interpolation
    class spline
    {
    public:
        enum bd_type
    	{
            first_deriv = 1,
            second_deriv = 2
        };
    
    private:
        std::vector<double> m_x,m_y;            // x,y coordinates of points
        // interpolation parameters
        // f(x) = a*(x-x_i)^3 + b*(x-x_i)^2 + c*(x-x_i) + y_i
        std::vector<double> m_a,m_b,m_c;        // spline coefficients
        double  m_b0, m_c0;                     // for left extrapol
        bd_type m_left, m_right;
        double  m_left_value, m_right_value;
        bool    m_force_linear_extrapolation;
    
    public:
        // set default boundary condition to be zero curvature at both ends
        spline(): m_left(second_deriv), m_right(second_deriv), m_left_value(0.0), m_right_value(0.0), m_force_linear_extrapolation(false)
        {
            ;
        }
    
        // optional, but if called it has to come be before set_points()
        void set_boundary(bd_type left, double left_value,
                          bd_type right, double right_value,
                          bool force_linear_extrapolation=false);
        void set_points(const std::vector<double>& x,  const std::vector<double>& y, bool cubic_spline=true);
        double operator() (double x) const;
    };
    
    // ---------------------------------------------------------------------
    // implementation part, which could be separated into a cpp file
    // ---------------------------------------------------------------------
    
    // band_matrix implementation
    // -------------------------
    
    band_matrix::band_matrix(int dim, int n_u, int n_l)
    {
        resize(dim, n_u, n_l);
    }
    void band_matrix::resize(int dim, int n_u, int n_l)
    {
        assert(dim>0);
        assert(n_u>=0);
        assert(n_l>=0);
        m_upper.resize(n_u+1);
        m_lower.resize(n_l+1);
        for(size_t i=0; i<m_upper.size(); i++) {
            m_upper[i].resize(dim);
        }
        for(size_t i=0; i<m_lower.size(); i++) {
            m_lower[i].resize(dim);
        }
    }
    int band_matrix::dim() const
    {
        if(m_upper.size()>0) {
            return m_upper[0].size();
        } else {
            return 0;
        }
    }
    
    
    // defines the new operator (), so that we can access the elements
    // by A(i,j), index going from i=0,...,dim()-1
    double & band_matrix::operator () (int i, int j)
    {
        int k=j-i;       // what band is the entry
        assert( (i>=0) && (i<dim()) && (j>=0) && (j<dim()) );
        assert( (-num_lower()<=k) && (k<=num_upper()) );
        // k=0 -> diogonal, k<0 lower left part, k>0 upper right part
        if(k>=0)   return m_upper[k][i];
        else	    return m_lower[-k][i];
    }
    double band_matrix::operator () (int i, int j) const
    {
        int k=j-i;       // what band is the entry
        assert( (i>=0) && (i<dim()) && (j>=0) && (j<dim()) );
        assert( (-num_lower()<=k) && (k<=num_upper()) );
        // k=0 -> diogonal, k<0 lower left part, k>0 upper right part
        if(k>=0)   return m_upper[k][i];
        else	    return m_lower[-k][i];
    }
    // second diag (used in LU decomposition), saved in m_lower
    double band_matrix::saved_diag(int i) const
    {
        assert( (i>=0) && (i<dim()) );
        return m_lower[0][i];
    }
    double & band_matrix::saved_diag(int i)
    {
        assert( (i>=0) && (i<dim()) );
        return m_lower[0][i];
    }
    
    // LR-Decomposition of a band matrix
    void band_matrix::lu_decompose()
    {
        int  i_max,j_max;
        int  j_min;
        double x;
    
        // preconditioning
        // normalize column i so that a_ii=1
        for(int i=0; i<this->dim(); i++) {
            assert(this->operator()(i,i)!=0.0);
            this->saved_diag(i)=1.0/this->operator()(i,i);
            j_min=std::max(0,i-this->num_lower());
            j_max=std::min(this->dim()-1,i+this->num_upper());
            for(int j=j_min; j<=j_max; j++) {
                this->operator()(i,j) *= this->saved_diag(i);
            }
            this->operator()(i,i)=1.0;          // prevents rounding errors
        }
    
        // Gauss LR-Decomposition
        for(int k=0; k<this->dim(); k++) {
            i_max=std::min(this->dim()-1,k+this->num_lower());  // num_lower not a mistake!
            for(int i=k+1; i<=i_max; i++) {
                assert(this->operator()(k,k)!=0.0);
                x=-this->operator()(i,k)/this->operator()(k,k);
                this->operator()(i,k)=-x;                         // assembly part of L
                j_max=std::min(this->dim()-1,k+this->num_upper());
                for(int j=k+1; j<=j_max; j++) {
                    // assembly part of R
                    this->operator()(i,j)=this->operator()(i,j)+x*this->operator()(k,j);
                }
            }
        }
    }
    // solves Ly=b
    std::vector<double> band_matrix::l_solve(const std::vector<double>& b) const
    {
        assert( this->dim()==(int)b.size() );
        std::vector<double> x(this->dim());
        int j_start;
        double sum;
        for(int i=0; i<this->dim(); i++) {
            sum=0;
            j_start=std::max(0,i-this->num_lower());
            for(int j=j_start; j<i; j++) sum += this->operator()(i,j)*x[j];
            x[i]=(b[i]*this->saved_diag(i)) - sum;
        }
        return x;
    }
    // solves Rx=y
    std::vector<double> band_matrix::r_solve(const std::vector<double>& b) const
    {
        assert( this->dim()==(int)b.size() );
        std::vector<double> x(this->dim());
        int j_stop;
        double sum;
        for(int i=this->dim()-1; i>=0; i--) {
            sum=0;
            j_stop=std::min(this->dim()-1,i+this->num_upper());
            for(int j=i+1; j<=j_stop; j++) sum += this->operator()(i,j)*x[j];
            x[i]=( b[i] - sum ) / this->operator()(i,i);
        }
        return x;
    }
    
    std::vector<double> band_matrix::lu_solve(const std::vector<double>& b,
            bool is_lu_decomposed)
    {
        assert( this->dim()==(int)b.size() );
        std::vector<double>  x,y;
        if(is_lu_decomposed==false) {
            this->lu_decompose();
        }
        y=this->l_solve(b);
        x=this->r_solve(y);
        return x;
    }
    
    // spline implementation
    // -----------------------
    void spline::set_boundary(spline::bd_type left, double left_value,
                              spline::bd_type right, double right_value,
                              bool force_linear_extrapolation)
    {
        assert(m_x.size()==0);          // set_points() must not have happened yet
        m_left=left;
        m_right=right;
        m_left_value=left_value;
        m_right_value=right_value;
        m_force_linear_extrapolation=force_linear_extrapolation;
    }
    
    void spline::set_points(const std::vector<double>& x,const std::vector<double>& y, bool cubic_spline)
    {
        assert(x.size()==y.size());
        assert(x.size()>2);
        m_x=x;
        m_y=y;
        int n = x.size();
        
    	// TODO: maybe sort x and y, rather than returning an error
        for(int i=0; i<n-1; i++)
    	{
            assert(m_x[i] < m_x[i+1]);
        }
    
        if(cubic_spline==true) 
    	{ 
    		// cubic spline interpolation
            // setting up the matrix and right hand side of the equation system
            // for the parameters b[]
            band_matrix A(n,1,1);
            std::vector<double>  rhs(n);
            for(int i=1; i<n-1; i++)
    		{
                A(i,i-1)=1.0/3.0*(x[i]-x[i-1]);
                A(i,i)=2.0/3.0*(x[i+1]-x[i-1]);
                A(i,i+1)=1.0/3.0*(x[i+1]-x[i]);
                rhs[i]=(y[i+1]-y[i])/(x[i+1]-x[i]) - (y[i]-y[i-1])/(x[i]-x[i-1]);
            }
            // boundary conditions
            if(m_left == spline::second_deriv) 
    		{
                // 2*b[0] = f''
                A(0,0)=2.0;
                A(0,1)=0.0;
                rhs[0]=m_left_value;
            } 
    		else if(m_left == spline::first_deriv) 
    		{
                // c[0] = f', needs to be re-expressed in terms of b:
                // (2b[0]+b[1])(x[1]-x[0]) = 3 ((y[1]-y[0])/(x[1]-x[0]) - f')
                A(0,0)=2.0*(x[1]-x[0]);
                A(0,1)=1.0*(x[1]-x[0]);
                rhs[0]=3.0*((y[1]-y[0])/(x[1]-x[0])-m_left_value);
            } 
    		else
    		{
                assert(false);
            }
           
    		if(m_right == spline::second_deriv)
    		{
                // 2*b[n-1] = f''
                A(n-1,n-1)=2.0;
                A(n-1,n-2)=0.0;
                rhs[n-1]=m_right_value;
            } 
    		else if(m_right == spline::first_deriv)
    		{
                // c[n-1] = f', needs to be re-expressed in terms of b:
                // (b[n-2]+2b[n-1])(x[n-1]-x[n-2])
                // = 3 (f' - (y[n-1]-y[n-2])/(x[n-1]-x[n-2]))
                A(n-1,n-1)=2.0*(x[n-1]-x[n-2]);
                A(n-1,n-2)=1.0*(x[n-1]-x[n-2]);
                rhs[n-1]=3.0*(m_right_value-(y[n-1]-y[n-2])/(x[n-1]-x[n-2]));
            } else {
                assert(false);
            }
    
            // solve the equation system to obtain the parameters b[]
            m_b=A.lu_solve(rhs);
    
            // calculate parameters a[] and c[] based on b[]
            m_a.resize(n);
            m_c.resize(n);
            for(int i=0; i<n-1; i++)
    		{
                m_a[i]=1.0/3.0*(m_b[i+1]-m_b[i])/(x[i+1]-x[i]);
                m_c[i]=(y[i+1]-y[i])/(x[i+1]-x[i])
                       - 1.0/3.0*(2.0*m_b[i]+m_b[i+1])*(x[i+1]-x[i]);
            }
        } 
    	else 
    	{ 
    		// linear interpolation
            m_a.resize(n);
            m_b.resize(n);
            m_c.resize(n);
            for(int i=0; i<n-1; i++)
    		{
                m_a[i]=0.0;
                m_b[i]=0.0;
                m_c[i]=(m_y[i+1]-m_y[i])/(m_x[i+1]-m_x[i]);
            }
        }
    
        // for left extrapolation coefficients
        m_b0 = (m_force_linear_extrapolation==false) ? m_b[0] : 0.0;
        m_c0 = m_c[0];
    
        // for the right extrapolation coefficients
        // f_{n-1}(x) = b*(x-x_{n-1})^2 + c*(x-x_{n-1}) + y_{n-1}
        double h=x[n-1]-x[n-2];
        // m_b[n-1] is determined by the boundary condition
        m_a[n-1]=0.0;
        m_c[n-1]=3.0*m_a[n-2]*h*h+2.0*m_b[n-2]*h+m_c[n-2];   // = f'_{n-2}(x_{n-1})
        if(m_force_linear_extrapolation==true)
            m_b[n-1]=0.0;
    }
    
    double spline::operator() (double x) const
    {
        size_t n=m_x.size();
        // find the closest point m_x[idx] < x, idx=0 even if x<m_x[0]
        std::vector<double>::const_iterator it;
        it=std::lower_bound(m_x.begin(),m_x.end(),x);
        int idx=std::max( int(it-m_x.begin())-1, 0);
    
        double h=x-m_x[idx];
        double interpol;
        if(x<m_x[0])
    	{
            // extrapolation to the left
            interpol=(m_b0*h + m_c0)*h + m_y[0];
        } 
    	else if(x>m_x[n-1])
    	{
            // extrapolation to the right
            interpol=(m_b[n-1]*h + m_c[n-1])*h + m_y[n-1];
        } 
    	else 
    	{
            // interpolation
            interpol=((m_a[idx]*h + m_b[idx])*h + m_c[idx])*h + m_y[idx];
        }
        return interpol;
    }
    
    } // namespace tk
    
    } // namespace
    
    #endif /* TK_SPLINE_H */
    

    显示效果如下

  • 相关阅读:
    AJAX以及XMLHttpRequest
    理解Promise对象
    HTTP报文整理
    前端 — URL、URI、URN概念和区别整理,以及URL语法规则
    gulp与webpack的区别
    Sass和less的区别是什么?用哪个好
    Vue3.0 && Vue3.0初体验 一
    Promise入门详解和基本用法
    js对象方法大全
    hash模式和history模式 实现原理及区别
  • 原文地址:https://www.cnblogs.com/flyinggod/p/12826647.html
Copyright © 2011-2022 走看看