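  • Cubic spline interpolation: concept and implementation

    Reference blog: https://blog.csdn.net/flyingleo1981/article/details/53008931

    Spline interpolation is an interpolation method commonly used in industrial design to obtain smooth curves, and the cubic spline is one of its most widely used forms. This post explains the principle of cubic spline interpolation as plainly as possible and includes an implementation in C++.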

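    1. Principle of the cubic spline

    Suppose we are given the following n+1 nodes: $(x_0, y_0), (x_1, y_1), \dots, (x_n, y_n)$, with $a = x_0 < x_1 < \dots < x_n = b$.

    1.1 Definition

    A spline curve $S(x)$ is a piecewise-defined function. Given n+1 data points there are n intervals, and a cubic spline satisfies the following conditions:

    1. On each interval $[x_i, x_{i+1}]$ (i = 0, 1, …, n-1, with x increasing), $S(x) = S_i(x)$ is a cubic polynomial.
    2. $S(x_i) = y_i$ (i = 0, 1, …, n).
    3. $S(x)$, its first derivative $S'(x)$ and its second derivative $S''(x)$ are all continuous on $[a, b]$, i.e. the curve $S(x)$ is smooth.

    The n cubic pieces can therefore be written as:

    $S_i(x) = a_i + b_i (x - x_i) + c_i (x - x_i)^2 + d_i (x - x_i)^3$, i = 0, 1, …, n-1

    where $a_i, b_i, c_i, d_i$ are the 4n unknown coefficients.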
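    To make the piecewise form concrete, here is a minimal C++ sketch that evaluates a single piece $S_i(x)$ once its four coefficients are known. The names `CubicPiece` and `eval_piece` are hypothetical and were chosen for this example. The polynomial is evaluated with Horner's scheme, which avoids computing explicit powers.

```cpp
#include <cassert>

// Hypothetical type for one spline piece
// S_i(x) = a + b*(x - xi) + c*(x - xi)^2 + d*(x - xi)^3 on [x_i, x_{i+1}].
struct CubicPiece {
    double xi;          // left end x_i of the interval
    double a, b, c, d;  // coefficients a_i, b_i, c_i, d_i
};

// Evaluate S_i(x) with Horner's scheme: a + t*(b + t*(c + t*d)), t = x - x_i.
double eval_piece(const CubicPiece& p, double x)
{
    const double t = x - p.xi;
    return p.a + t * (p.b + t * (p.c + t * p.d));
}
```

    For instance, with $x_i = 1$ and $(a_i, b_i, c_i, d_i) = (2, 3, 4, 5)$, `eval_piece` returns 14 at $x = 2$, which is $2 + 3 + 4 + 5$.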

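    1.2 Solving for the coefficients

    Known conditions:

    1. n+1 data points $(x_i, y_i)$, i = 0, 1, …, n
    2. Each segment is a cubic polynomial
    3. The spline is twice continuously differentiable at the nodes
    4. The behavior at the two endpoints (natural, clamped, or not-a-knot boundary)

    From the given nodes we solve for the coefficients of every segment, which yields the explicit expression of each piece of the curve.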
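    Counting equations makes the role of condition 4 explicit. Conditions 1 to 3 supply $4n - 2$ equations for $4n$ unknowns; with n = 3 intervals, for example, that is 10 equations for 12 unknowns. Each endpoint therefore has to contribute one more condition:

```latex
\underbrace{4n}_{\text{unknowns } a_i,\,b_i,\,c_i,\,d_i}
\;-\;\Big[\underbrace{2n}_{S_i(x_i)=y_i,\;S_i(x_{i+1})=y_{i+1}}
\;+\;\underbrace{(n-1)}_{S'\text{ continuous at } x_1,\dots,x_{n-1}}
\;+\;\underbrace{(n-1)}_{S''\text{ continuous at } x_1,\dots,x_{n-1}}\Big] \;=\; 2
```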

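    • Interpolation and continuity:

    $S_i(x_i) = y_i,\quad S_i(x_{i+1}) = y_{i+1}$, where i = 0, 1, …, n-1

    • Continuity of the derivatives:

    $S'_i(x_{i+1}) = S'_{i+1}(x_{i+1}),\quad S''_i(x_{i+1}) = S''_{i+1}(x_{i+1})$, where i = 0, 1, …, n-2

    • Derivatives of the spline pieces:

    $S'_i(x) = b_i + 2c_i (x - x_i) + 3d_i (x - x_i)^2,\quad S''_i(x) = 2c_i + 6d_i (x - x_i)$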
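    The two derivative formulas translate directly into code. This is a minimal sketch that assumes the same coefficient layout as $S_i(x)$; `Piece`, `first_derivative` and `second_derivative` are hypothetical names.

```cpp
#include <cassert>

// Hypothetical type: coefficients of S_i(x) = a + b*t + c*t^2 + d*t^3, t = x - x_i.
struct Piece { double xi, a, b, c, d; };

// S_i'(x) = b + 2c*t + 3d*t^2
double first_derivative(const Piece& p, double x)
{
    const double t = x - p.xi;
    return p.b + t * (2.0 * p.c + 3.0 * p.d * t);
}

// S_i''(x) = 2c + 6d*t
double second_derivative(const Piece& p, double x)
{
    const double t = x - p.xi;
    return 2.0 * p.c + 6.0 * p.d * t;
}
```

    At $x = x_i$ (t = 0) the two functions return $b_i$ and $2c_i$, which are the values the substitution steps rely on.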

    Substitute the step size $h_i = x_{i+1} - x_i$ into the spline conditions:

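    1. From $S_i(x_i) = y_i$ (i = 0, 1, …, n-1) we get: $a_i = y_i$
    2. From $S_i(x_{i+1}) = y_{i+1}$ (i = 0, 1, …, n-1) we get: $a_i + h_i b_i + h_i^2 c_i + h_i^3 d_i = y_{i+1}$
    3. From $S'_i(x_{i+1}) = S'_{i+1}(x_{i+1})$ (i = 0, 1, …, n-2) we get: $S'_{i+1}(x_{i+1}) = b_{i+1}$

       It follows that:

       $b_i + 2c_i h_i + 3d_i h_i^2 - b_{i+1} = 0$

    4. From $S''_i(x_{i+1}) = S''_{i+1}(x_{i+1})$ (i = 0, 1, …, n-2) we get: $2c_i + 6h_i d_i - 2c_{i+1} = 0$

    • Let $m_i = S''_i(x_i) = 2c_i$ (and $m_n = S''_{n-1}(x_n)$). Then $2c_i + 6h_i d_i - 2c_{i+1} = 0$ can be written as $m_i + 6h_i d_i - m_{i+1} = 0$, which gives: $d_i = \dfrac{m_{i+1} - m_i}{6h_i}$
    • Substituting $c_i$ and $d_i$ into $y_i + h_i b_i + h_i^2 c_i + h_i^3 d_i = y_{i+1}$ gives: $b_i = \dfrac{y_{i+1} - y_i}{h_i} - \dfrac{h_i}{2} m_i - \dfrac{h_i}{6}(m_{i+1} - m_i)$
    • Substituting $b_i$, $c_i$ and $d_i$ into $b_i + 2c_i h_i + 3d_i h_i^2 = b_{i+1}$ (i = 0, 1, …, n-2) gives:

      $h_i m_i + 2(h_i + h_{i+1}) m_{i+1} + h_{i+1} m_{i+2} = 6\left[\dfrac{y_{i+2} - y_{i+1}}{h_{i+1}} - \dfrac{y_{i+1} - y_i}{h_i}\right]$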
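    The last equation is the one that actually gets solved. For each i = 0, …, n-2 it links three consecutive unknowns $m_i, m_{i+1}, m_{i+2}$. The sketch below builds those n-1 rows from the data. It uses a hypothetical helper, `interior_rows`, that is not taken from the original post, and it assumes strictly increasing x and at least three nodes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper type: one equation
//   lower*m_i + diag*m_{i+1} + upper*m_{i+2} = rhs
struct Row { double lower, diag, upper, rhs; };

// Build the n-1 interior equations (i = 0..n-2) in the unknowns m_0..m_n
// from n+1 nodes with strictly increasing x.
std::vector<Row> interior_rows(const std::vector<double>& x,
                               const std::vector<double>& y)
{
    assert(x.size() == y.size() && x.size() >= 3);
    const std::size_t n = x.size() - 1;               // number of intervals
    std::vector<Row> rows;
    for (std::size_t i = 0; i + 2 <= n; ++i) {        // i = 0..n-2
        const double h0 = x[i + 1] - x[i];            // h_i
        const double h1 = x[i + 2] - x[i + 1];        // h_{i+1}
        const double s0 = (y[i + 1] - y[i]) / h0;     // slope on [x_i, x_{i+1}]
        const double s1 = (y[i + 2] - y[i + 1]) / h1; // slope on [x_{i+1}, x_{i+2}]
        rows.push_back({h0, 2.0 * (h0 + h1), h1, 6.0 * (s1 - s0)});
    }
    return rows;
}
```

    With x = (0, 1, 2, 3) and y = (0, 1, 0, 1), every $h_i$ equals 1. The two rows are then $m_0 + 4m_1 + m_2 = -12$ and $m_1 + 4m_2 + m_3 = 12$.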

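    Endpoint conditions

    Since i ranges over 0, 1, …, n-2, there are n-1 equations but n+1 unknowns $m_0, \dots, m_n$. Two more equations are needed to solve the system, so we impose constraints on the derivatives at the endpoints $x_0$ and $x_n$. The choice is not unique; three commonly used conditions are listed below.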

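    1. Free (natural) boundary

    No force bends the two ends of the curve, i.e. $S'' = 0$ at both ends. Concretely, $m_0 = 0$ and $m_n = 0$. The system to be solved can then be written as:

    $$
    \begin{pmatrix}
    1 & 0 & 0 & \cdots & 0 \\
    h_0 & 2(h_0+h_1) & h_1 & 0 & \cdots \\
    0 & h_1 & 2(h_1+h_2) & h_2 & \cdots \\
    \vdots & & \ddots & \ddots & \ddots \\
    \cdots & 0 & h_{n-2} & 2(h_{n-2}+h_{n-1}) & h_{n-1} \\
    \cdots & 0 & 0 & 0 & 1
    \end{pmatrix}
    \begin{pmatrix} m_0 \\ m_1 \\ m_2 \\ \vdots \\ m_{n-1} \\ m_n \end{pmatrix}
    = 6
    \begin{pmatrix}
    0 \\
    \frac{y_2-y_1}{h_1} - \frac{y_1-y_0}{h_0} \\
    \frac{y_3-y_2}{h_2} - \frac{y_2-y_1}{h_1} \\
    \vdots \\
    \frac{y_n-y_{n-1}}{h_{n-1}} - \frac{y_{n-1}-y_{n-2}}{h_{n-2}} \\
    0
    \end{pmatrix}
    $$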
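    The sketch below shows one way to store the natural-boundary system as three diagonals plus a right-hand side. It assumes strictly increasing x, and the `Tridiagonal` struct and `natural_system` function are illustrative names only. Row k = 1, …, n-1 is the interior equation with i = k-1, and rows 0 and n encode $m_0 = 0$ and $m_n = 0$.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical container for a tridiagonal system; row k reads
//   sub[k]*m_{k-1} + diag[k]*m_k + sup[k]*m_{k+1} = rhs[k]
// (sub[0] and sup[N-1] are unused and stay 0).
struct Tridiagonal {
    std::vector<double> sub, diag, sup, rhs;
};

// Assemble the (n+1)x(n+1) system for m_0..m_n with natural ends m_0 = m_n = 0.
Tridiagonal natural_system(const std::vector<double>& x, const std::vector<double>& y)
{
    assert(x.size() == y.size() && x.size() >= 3);
    const std::size_t N = x.size();            // N = n+1 unknowns
    Tridiagonal s;
    s.sub.assign(N, 0.0);
    s.diag.assign(N, 0.0);
    s.sup.assign(N, 0.0);
    s.rhs.assign(N, 0.0);
    s.diag[0] = 1.0;                           // first row: m_0 = 0
    s.diag[N - 1] = 1.0;                       // last row:  m_n = 0
    for (std::size_t k = 1; k + 1 < N; ++k) {  // interior rows k = 1..n-1
        const double h0 = x[k] - x[k - 1];     // h_{k-1}
        const double h1 = x[k + 1] - x[k];     // h_k
        s.sub[k]  = h0;
        s.diag[k] = 2.0 * (h0 + h1);
        s.sup[k]  = h1;
        s.rhs[k]  = 6.0 * ((y[k + 1] - y[k]) / h1 - (y[k] - y[k - 1]) / h0);
    }
    return s;
}
```

    For x = (0, 1, 2, 3) and y = (0, 1, 0, 1), this gives diag = (1, 4, 4, 1) and rhs = (0, -12, 12, 0).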

    2. Clamped (fixed) boundary

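    The derivatives at the two endpoints are prescribed; call them A and B, i.e. $S'_0(x_0) = A$ and $S'_{n-1}(x_n) = B$. Using the expression for $b_i$ above and $S'_{n-1}(x_n) = b_{n-1} + 2c_{n-1} h_{n-1} + 3d_{n-1} h_{n-1}^2$, this gives

    $2h_0 m_0 + h_0 m_1 = 6\left[\dfrac{y_1 - y_0}{h_0} - A\right]$

    $h_{n-1} m_{n-1} + 2h_{n-1} m_n = 6\left[B - \dfrac{y_n - y_{n-1}}{h_{n-1}}\right]$

    Substituting these two equations into the system changes only its first and last rows. The first row of the coefficient matrix becomes $(2h_0,\ h_0,\ 0,\ \dots,\ 0)$ and the last row becomes $(0,\ \dots,\ 0,\ h_{n-1},\ 2h_{n-1})$, with the right-hand sides given above; the interior rows are unchanged.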
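    The two clamped rows can be computed on their own and then written over the first and last rows of the natural system. This is a minimal sketch: `BoundaryRow`, `clamped_first_row` and `clamped_last_row` are hypothetical names, and A and B are the prescribed end slopes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical type: one boundary equation c0*m_p + c1*m_{p+1} = rhs
struct BoundaryRow { double c0, c1, rhs; };

// First row for S'(x_0) = A:  2h_0*m_0 + h_0*m_1 = 6[(y_1 - y_0)/h_0 - A]
BoundaryRow clamped_first_row(const std::vector<double>& x,
                              const std::vector<double>& y, double A)
{
    assert(x.size() == y.size() && x.size() >= 2);
    const double h = x[1] - x[0];
    return {2.0 * h, h, 6.0 * ((y[1] - y[0]) / h - A)};
}

// Last row for S'(x_n) = B:  h_{n-1}*m_{n-1} + 2h_{n-1}*m_n = 6[B - (y_n - y_{n-1})/h_{n-1}]
BoundaryRow clamped_last_row(const std::vector<double>& x,
                             const std::vector<double>& y, double B)
{
    assert(x.size() == y.size() && x.size() >= 2);
    const std::size_t n = x.size() - 1;
    const double h = x[n] - x[n - 1];
    return {h, 2.0 * h, 6.0 * (B - (y[n] - y[n - 1]) / h)};
}
```

    As a consistency check, take $y = x^2$ on x = (0, 1, 2) with A = 0 and B = 4. The exact second derivative $m_i = 2$ satisfies both rows: $2\cdot 2 + 1\cdot 2 = 6$ and $1\cdot 2 + 2\cdot 2 = 6$.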

    3. Not-a-knot boundary

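    The third derivatives of the spline are required to match at the second node and at the second-to-last node, i.e.

    $S'''_0(x_1) = S'''_1(x_1)$ and $S'''_{n-2}(x_{n-1}) = S'''_{n-1}(x_{n-1})$

    Since $S'''_i(x) = 6d_i$ and $d_i = \dfrac{m_{i+1} - m_i}{6h_i}$, these conditions become $h_1(m_1 - m_0) = h_0(m_2 - m_1)$ and $h_{n-1}(m_{n-1} - m_{n-2}) = h_{n-2}(m_n - m_{n-1})$.

    In the new coefficient matrix the first row is $(-h_1,\ h_0 + h_1,\ -h_0,\ 0,\ \dots,\ 0)$ and the last row is $(0,\ \dots,\ 0,\ -h_{n-1},\ h_{n-2} + h_{n-1},\ -h_{n-2})$, both with right-hand side 0; the interior rows are unchanged. Each of these two rows has three nonzero entries, so the matrix is no longer strictly tridiagonal.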
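    Each not-a-knot row involves three unknowns. The sketch below computes the coefficients of both rows; `KnotRow`, `not_a_knot_first` and `not_a_knot_last` are hypothetical names. It requires at least three intervals. With n = 2, both conditions refer to the same node $x_1$ and coincide, which leaves the system singular.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical type: one boundary equation c0*m_p + c1*m_{p+1} + c2*m_{p+2} = 0
struct KnotRow { double c0, c1, c2; };

// First row (p = 0):  -h_1*m_0 + (h_0 + h_1)*m_1 - h_0*m_2 = 0
KnotRow not_a_knot_first(const std::vector<double>& x)
{
    assert(x.size() >= 4);   // n >= 3 intervals
    const double h0 = x[1] - x[0];
    const double h1 = x[2] - x[1];
    return {-h1, h0 + h1, -h0};
}

// Last row (p = n-2): -h_{n-1}*m_{n-2} + (h_{n-2} + h_{n-1})*m_{n-1} - h_{n-2}*m_n = 0
KnotRow not_a_knot_last(const std::vector<double>& x)
{
    assert(x.size() >= 4);
    const std::size_t n = x.size() - 1;
    const double ha = x[n - 1] - x[n - 2];   // h_{n-2}
    const double hb = x[n] - x[n - 1];       // h_{n-1}
    return {-hb, ha + hb, -ha};
}
```

    For data sampled from a single cubic polynomial, $S''$ is linear, so any $m_i = \alpha x_i + \beta$ satisfies both rows. For example, with x = (0, 1, 3, 4) and $m_i = x_i$, the first row gives $-2\cdot 0 + 3\cdot 1 - 1\cdot 3 = 0$.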

    [Figure: effect of the different endpoint conditions on the spline curve]

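    1.3 Algorithm summary

    Given n+1 data nodes $(x_0, y_0), (x_1, y_1), \dots, (x_n, y_n)$:

    1. Compute the step sizes $h_i = x_{i+1} - x_i$ (i = 0, 1, …, n-1).
    2. Substitute the data nodes and the chosen endpoint conditions into the matrix equation.
    3. Solve the matrix equation for the second-derivative values $m_i$. With natural or clamped ends the matrix is tridiagonal; for how to solve it, see my previous post, "Solving tridiagonal systems".
    4. Compute the coefficients of the spline:

       $a_i = y_i,\quad b_i = \dfrac{y_{i+1} - y_i}{h_i} - \dfrac{h_i}{2} m_i - \dfrac{h_i}{6}(m_{i+1} - m_i),\quad c_i = \dfrac{m_i}{2},\quad d_i = \dfrac{m_{i+1} - m_i}{6h_i}$

       where i = 0, 1, …, n-1.

    5. On each subinterval $x_i \le x \le x_{i+1}$, form the function $S_i(x) = a_i + b_i (x - x_i) + c_i (x - x_i)^2 + d_i (x - x_i)^3$.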
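    The five steps can be combined into a short, self-contained natural-spline sketch. It makes several assumptions:

    • It handles natural end conditions only.
    • The class name `NaturalSpline` is hypothetical.
    • Step 3 uses the standard Thomas algorithm (tridiagonal Gaussian elimination without pivoting) in place of the solver from the author's previous post. Skipping pivoting is safe here because every row of this system is diagonally dominant.
    • Outside $[x_0, x_n]$ the sketch simply extends the first or last cubic.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of a natural cubic spline following steps 1-5 (hypothetical class).
class NaturalSpline {
public:
    NaturalSpline(const std::vector<double>& x, const std::vector<double>& y)
        : x_(x), a_(y)                                   // step 4: a_i = y_i
    {
        assert(x.size() == y.size() && x.size() >= 3);
        const std::size_t n = x.size() - 1;              // number of intervals
        std::vector<double> h(n);
        for (std::size_t i = 0; i < n; ++i) {            // step 1: h_i
            h[i] = x[i + 1] - x[i];
            assert(h[i] > 0.0);
        }
        // Step 2: rows 0 and n encode m_0 = m_n = 0, interior rows as derived above.
        std::vector<double> sub(n + 1, 0.0), diag(n + 1, 1.0), sup(n + 1, 0.0), rhs(n + 1, 0.0);
        for (std::size_t k = 1; k < n; ++k) {
            sub[k]  = h[k - 1];
            diag[k] = 2.0 * (h[k - 1] + h[k]);
            sup[k]  = h[k];
            rhs[k]  = 6.0 * ((y[k + 1] - y[k]) / h[k] - (y[k] - y[k - 1]) / h[k - 1]);
        }
        // Step 3: Thomas algorithm, forward elimination then back substitution.
        for (std::size_t k = 1; k <= n; ++k) {
            const double w = sub[k] / diag[k - 1];
            diag[k] -= w * sup[k - 1];
            rhs[k]  -= w * rhs[k - 1];
        }
        std::vector<double> m(n + 1);
        m[n] = rhs[n] / diag[n];
        for (std::size_t k = n; k-- > 0;) {
            m[k] = (rhs[k] - sup[k] * m[k + 1]) / diag[k];
        }
        // Step 4: remaining coefficients b_i, c_i, d_i.
        b_.resize(n);
        c_.resize(n);
        d_.resize(n);
        for (std::size_t i = 0; i < n; ++i) {
            b_[i] = (y[i + 1] - y[i]) / h[i] - h[i] * m[i] / 2.0 - h[i] * (m[i + 1] - m[i]) / 6.0;
            c_[i] = m[i] / 2.0;
            d_[i] = (m[i + 1] - m[i]) / (6.0 * h[i]);
        }
    }

    // Step 5: pick i with x_i <= x <= x_{i+1} (first/last piece outside the range)
    // and evaluate S_i(x) with Horner's scheme.
    double operator()(double x) const
    {
        const std::size_t n = b_.size();
        std::size_t i = static_cast<std::size_t>(
            std::upper_bound(x_.begin(), x_.end(), x) - x_.begin());
        i = (i == 0) ? 0 : std::min(i - 1, n - 1);
        const double t = x - x_[i];
        return a_[i] + t * (b_[i] + t * (c_[i] + t * d_[i]));
    }

private:
    std::vector<double> x_, a_, b_, c_, d_;
};
```

    For example, `NaturalSpline s({0, 1, 2, 3}, {0, 1, 0, 1});` solves to second derivatives m = (0, -4, 4, 0). It returns 1 at the node x = 1 and approximately 0.75 at x = 0.5. Data lying on a straight line gives m = 0 everywhere, so the spline reproduces the line.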

    2. C++ implementation

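    Below is a C++ cubic spline interpolation implementation that uses the natural boundary by default. It is the "simple cubic spline interpolation library without external dependencies" (spline.h by Tino Kluge) used in the Udacity Path Planning course. Besides the natural boundary, set_boundary() also accepts a prescribed first derivative (clamped); not-a-knot is not implemented. Note that spline.h writes each piece as $f(x) = a(x-x_i)^3 + b(x-x_i)^2 + c(x-x_i) + y_i$. Its b[] therefore corresponds to $c_i = m_i/2$ above, and its interior equations are the tridiagonal equations above with $m_i = 2b_i$, divided by 6.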

    /*
     * spline.h
     *
     * simple cubic spline interpolation library without external
     * dependencies
     *
     * ---------------------------------------------------------------------
     * Copyright (C) 2011, 2014 Tino Kluge (ttk448 at gmail.com)
     *
     *  This program is free software; you can redistribute it and/or
     *  modify it under the terms of the GNU General Public License
     *  as published by the Free Software Foundation; either version 2
     *  of the License, or (at your option) any later version.
     *
     *  This program is distributed in the hope that it will be useful,
     *  but WITHOUT ANY WARRANTY; without even the implied warranty of
     *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     *  GNU General Public License for more details.
     *
     *  You should have received a copy of the GNU General Public License
     *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
     * ---------------------------------------------------------------------
     *
     */
    
    
    #ifndef TK_SPLINE_H
    #define TK_SPLINE_H
    
    #include <cstdio>
    #include <cassert>
    #include <vector>
    #include <algorithm>
    
    // unnamed namespace only because the implementation is in this
    // header file and we don't want to export symbols to the obj files
    namespace
    {
    
    namespace tk
    {
    
    // band matrix solver
    class band_matrix
    {
    private:
        std::vector< std::vector<double> > m_upper;  // upper band
        std::vector< std::vector<double> > m_lower;  // lower band
    public:
        band_matrix() {};                             // constructor
        band_matrix(int dim, int n_u, int n_l);       // constructor
        ~band_matrix() {};                            // destructor
        void resize(int dim, int n_u, int n_l);      // init with dim,n_u,n_l
        int dim() const;                             // matrix dimension
        int num_upper() const
        {
            return m_upper.size()-1;
        }
        int num_lower() const
        {
            return m_lower.size()-1;
        }
        // access operator
        double & operator () (int i, int j);            // write
        double   operator () (int i, int j) const;      // read
        // we can store an additional diagonal (in m_lower)
        double& saved_diag(int i);
        double  saved_diag(int i) const;
        void lu_decompose();
        std::vector<double> r_solve(const std::vector<double>& b) const;
        std::vector<double> l_solve(const std::vector<double>& b) const;
        std::vector<double> lu_solve(const std::vector<double>& b,
                                     bool is_lu_decomposed=false);
    
    };
    
    // spline interpolation
    class spline
    {
    public:
        enum bd_type
        {
            first_deriv = 1,
            second_deriv = 2
        };
    
    private:
        std::vector<double> m_x,m_y;            // x,y coordinates of points
        // interpolation parameters
        // f(x) = a*(x-x_i)^3 + b*(x-x_i)^2 + c*(x-x_i) + y_i
        std::vector<double> m_a,m_b,m_c;        // spline coefficients
        double  m_b0, m_c0;                     // for left extrapol
        bd_type m_left, m_right;
        double  m_left_value, m_right_value;
        bool    m_force_linear_extrapolation;
    
    public:
        // set default boundary condition to be zero curvature at both ends
        spline(): m_left(second_deriv), m_right(second_deriv), m_left_value(0.0), m_right_value(0.0), m_force_linear_extrapolation(false)
        {
            ;
        }
    
        // optional, but if called it has to come before set_points()
        void set_boundary(bd_type left, double left_value,
                          bd_type right, double right_value,
                          bool force_linear_extrapolation=false);
        void set_points(const std::vector<double>& x,  const std::vector<double>& y, bool cubic_spline=true);
        double operator() (double x) const;
    };
    
    // ---------------------------------------------------------------------
    // implementation part, which could be separated into a cpp file
    // ---------------------------------------------------------------------
    
    // band_matrix implementation
    // -------------------------
    
    band_matrix::band_matrix(int dim, int n_u, int n_l)
    {
        resize(dim, n_u, n_l);
    }
    void band_matrix::resize(int dim, int n_u, int n_l)
    {
        assert(dim>0);
        assert(n_u>=0);
        assert(n_l>=0);
        m_upper.resize(n_u+1);
        m_lower.resize(n_l+1);
        for(size_t i=0; i<m_upper.size(); i++) {
            m_upper[i].resize(dim);
        }
        for(size_t i=0; i<m_lower.size(); i++) {
            m_lower[i].resize(dim);
        }
    }
    int band_matrix::dim() const
    {
        if(m_upper.size()>0) {
            return m_upper[0].size();
        } else {
            return 0;
        }
    }
    
    
    // defines the new operator (), so that we can access the elements
    // by A(i,j), index going from i=0,...,dim()-1
    double & band_matrix::operator () (int i, int j)
    {
        int k=j-i;       // what band is the entry
        assert( (i>=0) && (i<dim()) && (j>=0) && (j<dim()) );
        assert( (-num_lower()<=k) && (k<=num_upper()) );
        // k=0 -> diagonal, k<0 lower left part, k>0 upper right part
        if(k>=0)   return m_upper[k][i];
        else       return m_lower[-k][i];
    }
    double band_matrix::operator () (int i, int j) const
    {
        int k=j-i;       // what band is the entry
        assert( (i>=0) && (i<dim()) && (j>=0) && (j<dim()) );
        assert( (-num_lower()<=k) && (k<=num_upper()) );
        // k=0 -> diagonal, k<0 lower left part, k>0 upper right part
        if(k>=0)   return m_upper[k][i];
        else       return m_lower[-k][i];
    }
    // second diag (used in LU decomposition), saved in m_lower
    double band_matrix::saved_diag(int i) const
    {
        assert( (i>=0) && (i<dim()) );
        return m_lower[0][i];
    }
    double & band_matrix::saved_diag(int i)
    {
        assert( (i>=0) && (i<dim()) );
        return m_lower[0][i];
    }
    
    // LR-Decomposition of a band matrix
    void band_matrix::lu_decompose()
    {
        int  i_max,j_max;
        int  j_min;
        double x;
    
        // preconditioning
        // normalize column i so that a_ii=1
        for(int i=0; i<this->dim(); i++) {
            assert(this->operator()(i,i)!=0.0);
            this->saved_diag(i)=1.0/this->operator()(i,i);
            j_min=std::max(0,i-this->num_lower());
            j_max=std::min(this->dim()-1,i+this->num_upper());
            for(int j=j_min; j<=j_max; j++) {
                this->operator()(i,j) *= this->saved_diag(i);
            }
            this->operator()(i,i)=1.0;          // prevents rounding errors
        }
    
        // Gauss LR-Decomposition
        for(int k=0; k<this->dim(); k++) {
            i_max=std::min(this->dim()-1,k+this->num_lower());  // num_lower not a mistake!
            for(int i=k+1; i<=i_max; i++) {
                assert(this->operator()(k,k)!=0.0);
                x=-this->operator()(i,k)/this->operator()(k,k);
                this->operator()(i,k)=-x;                         // assembly part of L
                j_max=std::min(this->dim()-1,k+this->num_upper());
                for(int j=k+1; j<=j_max; j++) {
                    // assembly part of R
                    this->operator()(i,j)=this->operator()(i,j)+x*this->operator()(k,j);
                }
            }
        }
    }
    // solves Ly=b
    std::vector<double> band_matrix::l_solve(const std::vector<double>& b) const
    {
        assert( this->dim()==(int)b.size() );
        std::vector<double> x(this->dim());
        int j_start;
        double sum;
        for(int i=0; i<this->dim(); i++) {
            sum=0;
            j_start=std::max(0,i-this->num_lower());
            for(int j=j_start; j<i; j++) sum += this->operator()(i,j)*x[j];
            x[i]=(b[i]*this->saved_diag(i)) - sum;
        }
        return x;
    }
    // solves Rx=y
    std::vector<double> band_matrix::r_solve(const std::vector<double>& b) const
    {
        assert( this->dim()==(int)b.size() );
        std::vector<double> x(this->dim());
        int j_stop;
        double sum;
        for(int i=this->dim()-1; i>=0; i--) {
            sum=0;
            j_stop=std::min(this->dim()-1,i+this->num_upper());
            for(int j=i+1; j<=j_stop; j++) sum += this->operator()(i,j)*x[j];
            x[i]=( b[i] - sum ) / this->operator()(i,i);
        }
        return x;
    }
    
    std::vector<double> band_matrix::lu_solve(const std::vector<double>& b,
            bool is_lu_decomposed)
    {
        assert( this->dim()==(int)b.size() );
        std::vector<double>  x,y;
        if(is_lu_decomposed==false) {
            this->lu_decompose();
        }
        y=this->l_solve(b);
        x=this->r_solve(y);
        return x;
    }
    
    // spline implementation
    // -----------------------
    void spline::set_boundary(spline::bd_type left, double left_value,
                              spline::bd_type right, double right_value,
                              bool force_linear_extrapolation)
    {
        assert(m_x.size()==0);          // set_points() must not have happened yet
        m_left=left;
        m_right=right;
        m_left_value=left_value;
        m_right_value=right_value;
        m_force_linear_extrapolation=force_linear_extrapolation;
    }
    
    void spline::set_points(const std::vector<double>& x,const std::vector<double>& y, bool cubic_spline)
    {
        assert(x.size()==y.size());
        assert(x.size()>2);
        m_x=x;
        m_y=y;
        int n = x.size();
        
        // TODO: maybe sort x and y, rather than returning an error
        for(int i=0; i<n-1; i++)
        {
            assert(m_x[i] < m_x[i+1]);
        }

        if(cubic_spline==true)
        {
            // cubic spline interpolation
            // setting up the matrix and right hand side of the equation system
            // for the parameters b[]
            band_matrix A(n,1,1);
            std::vector<double>  rhs(n);
            for(int i=1; i<n-1; i++)
            {
                A(i,i-1)=1.0/3.0*(x[i]-x[i-1]);
                A(i,i)=2.0/3.0*(x[i+1]-x[i-1]);
                A(i,i+1)=1.0/3.0*(x[i+1]-x[i]);
                rhs[i]=(y[i+1]-y[i])/(x[i+1]-x[i]) - (y[i]-y[i-1])/(x[i]-x[i-1]);
            }
            // boundary conditions
            if(m_left == spline::second_deriv)
            {
                // 2*b[0] = f''
                A(0,0)=2.0;
                A(0,1)=0.0;
                rhs[0]=m_left_value;
            }
            else if(m_left == spline::first_deriv)
            {
                // c[0] = f', needs to be re-expressed in terms of b:
                // (2b[0]+b[1])(x[1]-x[0]) = 3 ((y[1]-y[0])/(x[1]-x[0]) - f')
                A(0,0)=2.0*(x[1]-x[0]);
                A(0,1)=1.0*(x[1]-x[0]);
                rhs[0]=3.0*((y[1]-y[0])/(x[1]-x[0])-m_left_value);
            }
            else
            {
                assert(false);
            }

            if(m_right == spline::second_deriv)
            {
                // 2*b[n-1] = f''
                A(n-1,n-1)=2.0;
                A(n-1,n-2)=0.0;
                rhs[n-1]=m_right_value;
            }
            else if(m_right == spline::first_deriv)
            {
                // c[n-1] = f', needs to be re-expressed in terms of b:
                // (b[n-2]+2b[n-1])(x[n-1]-x[n-2])
                // = 3 (f' - (y[n-1]-y[n-2])/(x[n-1]-x[n-2]))
                A(n-1,n-1)=2.0*(x[n-1]-x[n-2]);
                A(n-1,n-2)=1.0*(x[n-1]-x[n-2]);
                rhs[n-1]=3.0*(m_right_value-(y[n-1]-y[n-2])/(x[n-1]-x[n-2]));
            }
            else
            {
                assert(false);
            }

            // solve the equation system to obtain the parameters b[]
            m_b=A.lu_solve(rhs);

            // calculate parameters a[] and c[] based on b[]
            m_a.resize(n);
            m_c.resize(n);
            for(int i=0; i<n-1; i++)
            {
                m_a[i]=1.0/3.0*(m_b[i+1]-m_b[i])/(x[i+1]-x[i]);
                m_c[i]=(y[i+1]-y[i])/(x[i+1]-x[i])
                       - 1.0/3.0*(2.0*m_b[i]+m_b[i+1])*(x[i+1]-x[i]);
            }
        }
        else
        {
            // linear interpolation: a[] and b[] are all zero; assign() (rather
            // than resize()) also clears values left over from an earlier call
            m_a.assign(n, 0.0);
            m_b.assign(n, 0.0);
            m_c.resize(n);
            for(int i=0; i<n-1; i++)
            {
                m_c[i]=(m_y[i+1]-m_y[i])/(m_x[i+1]-m_x[i]);
            }
        }
    
        // for left extrapolation coefficients
        m_b0 = (m_force_linear_extrapolation==false) ? m_b[0] : 0.0;
        m_c0 = m_c[0];
    
        // for the right extrapolation coefficients
        // f_{n-1}(x) = b*(x-x_{n-1})^2 + c*(x-x_{n-1}) + y_{n-1}
        double h=x[n-1]-x[n-2];
        // m_b[n-1] is determined by the boundary condition
        m_a[n-1]=0.0;
        m_c[n-1]=3.0*m_a[n-2]*h*h+2.0*m_b[n-2]*h+m_c[n-2];   // = f'_{n-2}(x_{n-1})
        if(m_force_linear_extrapolation==true)
            m_b[n-1]=0.0;
    }
    
    double spline::operator() (double x) const
    {
        size_t n=m_x.size();
        // find the closest point m_x[idx] < x, idx=0 even if x<m_x[0]
        std::vector<double>::const_iterator it;
        it=std::lower_bound(m_x.begin(),m_x.end(),x);
        int idx=std::max( int(it-m_x.begin())-1, 0);
    
        double h=x-m_x[idx];
        double interpol;
        if(x<m_x[0])
        {
            // extrapolation to the left
            interpol=(m_b0*h + m_c0)*h + m_y[0];
        }
        else if(x>m_x[n-1])
        {
            // extrapolation to the right
            interpol=(m_b[n-1]*h + m_c[n-1])*h + m_y[n-1];
        }
        else
        {
            // interpolation
            interpol=((m_a[idx]*h + m_b[idx])*h + m_c[idx])*h + m_y[idx];
        }
        return interpol;
    }
    
    } // namespace tk
    
    } // namespace
    
    #endif /* TK_SPLINE_H */
    

    [Figure: plot of the resulting interpolated curve]

  • Original article: https://www.cnblogs.com/flyinggod/p/12826647.html