zoukankan      html  css  js  c++  java
  • 拉格朗日插值

    公式:$f(x)=sum_{i=1}^{n} y_{i} prod_{i eq j} frac{x-x_{j}}{x_{i}-x_{j}}$.    

    这个式子正常算的话是 $O(n^2)$ 的,如果遇到 $x$ 是连续的情况可以优化到 $O(n log n)$.   

    但是有些时候我们只知道 $f(x)$ 在 $x=k$ 时的点值是不够的,有时必须求出这个多项式每一位系数.    

    多项式快速插值可以做到 $O(n log^2 n)$,但是快速插值非常非常难写,用处并不多.     

    相比之下,有一种简易的写法可以在 $O(n^2)$ 的时间复杂度内通过 $n$ 个不同的点来还原一个 $n-1$ 次多项式.    

    插值公式中 $prod_{j eq i} (x-x_{j})$ 是比较难求的,其他地方由于都是基于整数的运算,所以比较简单.     

    先令 $f_{i,j}$ 表示考虑前 $i$ 个点 $(x,y)$,$x^j$ 前的系数.    

    那么有转移:$f_{i,j}=f_{i-1,j-1}+f_{i-1,j} imes (-x_{i})$ 即分别表示当前位的贡献为 $x^1 / -x_{i}$.      

    求出这个后,我们枚举 $i$,然后想 $O(n)$ 计算 $h(x)=prod_{i eq j} (x-x_{j})$.   

    令 $k1[i],k2[i]$ 分别表示 $h(x)$ 的 $x^i$ 前的系数,强制让第 $i$ 位贡献 $x^1$ 时 $x^i$ 前的系数.    

    由于有 $k2$ 这个强制贡献的状态,转移就比较简单:

    $k1[i] leftarrow k2[i+1]$ 

    $f_{n,i}=k1[i] imes (-x_{i}) +k2[i] Rightarrow k2[i]=f_{n,i}+k1[i] imes (x_{i})$.    

    算出 $k2$ 后把 $y_{i}$ 及插值公式中分母的贡献乘上然后累加到答案数组中即可.      

    应用:

    求 $sum_{i=1}^{n} i^k$.    

    这是一个关于 $n$ 的 $k+1$ 次多项式.  

    所以可以取 $k+2$ 个点带进去,然后用拉格朗日插值法来求值.    

    具体,$f(k)=sum_{i=1}^{n} y_{i} prod_{}^{i eq j}frac{k-x_{j}}{x_{i}-x_{j}}$    

    由于点可以做到取 $x$ 连续的,所以提前预处理前缀/后缀积极可以做到 $O(n log n)$.  

    code: 

    #include <cstdio>  
    #include <vector>  
    #include <cstring>
    #include <algorithm>  
    #define N 1000009  
    #define ll long long 
    #define mod 1000000007
    #define setIO(s) freopen(s".in","r",stdin) 
    using namespace std;   
    int f[N]; 
    int ifac[N],fac[N],pre[N],suf[N],inv[N],n,K;  
    int qpow(int x,int y) {  
        int tmp=1; 
        for(;y;y>>=1,x=(ll)x*x%mod)  
            if(y&1) tmp=(ll)tmp*x%mod;   
        return tmp;    
    } 
    int INV(int x) { return qpow(x,mod-2); }         
    void init() {  
        ifac[0]=fac[0]=inv[1]=1;  
        for(int i=2;i<N;++i) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;     
        inv[0]=1;  
        for(int i=1;i<N;++i) {
            fac[i]=(ll)fac[i-1]*i%mod;
            ifac[i]=(ll)ifac[i-1]*inv[i]%mod;   
        }
        pre[0]=1,suf[n+1]=1;  
        for(int i=1;i<=n;++i)  pre[i]=(ll)pre[i-1]*(K-i+mod)%mod;      
        for(int i=n;i>=1;--i)  suf[i]=(ll)suf[i+1]*(K-i+mod)%mod;             
    } 
    int sol() { 
        int ans=0; 
        for(int i=1;i<=n;++i) {  
            int a1=(ll)ifac[i-1]*ifac[n-i]%mod; 
            if((n-i)&1) a1=(ll)a1*(mod-1)%mod;    
            int a2=(ll)pre[i-1]*suf[i+1]%mod;   
            (ans+=(ll)f[i]*a1%mod*a2%mod)%=mod; 
        }  
        return ans;  
    }
    int main() {          
        // setIO("input");   
        scanf("%d%d",&K,&n),n+=2;                           
        init();      
        for(int i=1;i<=n;++i) { 
            f[i]=(ll)(f[i-1]+qpow(i,n-2))%mod;        
        }    
        printf("%d
    ",sol());  
        return 0; 
    }
    

      

    还原多项式系数

    #include <cstdio>  
    #include <cstring>
    #include <algorithm>  
    #define N 2008   
    #define ll long long
    #define mod 998244353
    #define setIO(s) freopen(s".in","r",stdin) 
    using namespace std;     
    int f[N][N],k1[N],k2[N],s[N],n;    
    struct point {  
        int x,y;  
        point(int x=0,int y=0):x(x),y(y){}  
    }a[N];  
    int ADD(int x,int y) { 
        return (ll)(x+y)%mod; 
    } 
    int DEC(int x,int y) { 
        return (ll)(x-y+mod)%mod; 
    } 
    int MUL(int x,int y) { 
        return (ll)x*y%mod;   
    }
    int qpow(int x,int y) { 
        int tmp=1; 
        for(;y;y>>=1,x=MUL(x,x)) 
            if(y&1) { 
                tmp=MUL(tmp,x);  
            } 
        return tmp;   
    }
    int get_inv(int x) { return qpow(x,mod-2); }  
    void init() {  
        f[0][0]=1;   
        for(int i=1;i<=n;++i) {  
            for(int j=1;j<=i;++j) {   
                f[i][j]=ADD(f[i-1][j-1],MUL(mod-a[i].x,f[i-1][j]));        
            }
            f[i][0]=MUL(f[i-1][0],mod-a[i].x);               
        }
    }           
    int main() {  
        // setIO("input");
        int X;   
        scanf("%d%d",&n,&X);  
        for(int i=1;i<=n;++i) {
            scanf("%d%d",&a[i].x,&a[i].y);  
        }
        init();               
        for(int i=1;i<=n;++i) {   
            for(int j=0;j<=n;++j) k2[j]=f[n][j];   
            for(int j=n-1;j>=0;--j) { 
                k1[j]=k2[j+1];   
                k2[j]=ADD(k2[j],MUL(k1[j],a[i].x));           
            }           
            int inv=1;  
            for(int j=1;j<=n;++j) 
                if(i!=j) {
                    inv=(ll)inv*(a[i].x-a[j].x+mod)%mod;   
                }
            inv=get_inv(inv);   
            for(int j=0;j<=n-1;++j) { 
                (s[j]+=(ll)inv*a[i].y%mod*k1[j]%mod)%=mod;   
            }
        }
        int ans=0;      
        for(int i=n-1;i>=0;--i) { 
            ans=(ll)((ll)ans*X%mod+s[i])%mod;    
        }  
        printf("%d
    ",ans); 
        return 0;
    }
    

      

    例题

    CF917D Stranger Trees

    给你一颗树,求 $n$ 个点有多少个生成树满足该生成树与给定树有 $k$ 条边是重合的.    

    题解:

    先对完全图构建矩阵,然后将原树上的边 $(x,y)$ 在矩阵中的边权标记成 $x^1$,其余边权为 $1$.  

    矩阵树定理求的是所有生成树边权乘积之和,那么要是可以对含 $x$ 的矩阵求行列式的话可以直接得出答案.   

    但是复杂度太高,而且难写(写不了)    

    所以用 $n$ 个不同的整数来替换那个 $x^1$,然后跑出来 $n$ 个结果,用拉格朗日插值还原出多项式的系数即可.    

    #include <cstdio> 
    #include <vector>
    #include <cstring>
    #include <algorithm>      
    #define N 103
    #define ll long long
    #define mod 1000000007
    #define setIO(s) freopen(s".in","r",stdin)
    using namespace std;     
    int n;   
    int A[N],B[N]; 
    int f[N][N],k1[N],k2[N],ans[N];
    int deg[N][N],con[N][N],a[N][N];  
    struct point {
        int x,y; 
        point(int x=0,int y=0):x(x),y(y){} 
    }p[N]; 
    int qpow(int x,int y) {
        int tmp=1;
        for(;y;y>>=1,x=(ll)x*x%mod)
            if(y&1) {  
                tmp=(ll)tmp*x%mod; 
            }
        return tmp;
    }
    int get_inv(int x) {
        return qpow(x,mod-2); 
    }    
    int ADD(int x,int y) {
        return (ll)(x+y)%mod;
    }
    int DEC(int x,int y) {
        return (ll)(x-y+mod)%mod; 
    } 
    int MUL(int x,int y) {
        return (ll)x*y%mod;  
    }
    int gauss() { 
        int ans=1; 
        for(int i=1;i<n;++i) {
            for(int j=i+1;j<n;++j) {   
                while(a[j][i]) {
                    int t=a[i][i]/a[j][i]; 
                    for(int k=i;k<n;++k) {
                        a[i][k]=DEC(a[i][k],MUL(t,a[j][k]));         
                    }
                    swap(a[j],a[i]);  
                    ans=(ll)ans*(mod-1)%mod;   
                }
            }
            if(!a[i][i]) {
                return 0; 
            }
        }
        for(int i=1;i<n;++i) {
            ans=(ll)ans*a[i][i]%mod;   
        }
        return ans;  
    }        
    int cal(int val) { 
        for(int i=1;i<=n;++i) { 
            for(int j=1;j<=n;++j) {
                a[i][j]=mod-1;
            }
        }  
        for(int i=1;i<=n;++i) {
            a[i][i]=n-1;   
        }
        for(int i=1;i<n;++i) { 
            int x=A[i],y=B[i];                       
            a[x][x]=(ll)(DEC(a[x][x],1)+val)%mod;   
            a[y][y]=(ll)(DEC(a[y][y],1)+val)%mod;   
            a[x][y]=(ll)(a[x][y]+1-val+mod)%mod; 
            a[y][x]=(ll)(a[y][x]+1-val+mod)%mod;
        }
        return gauss();  
    }
    void init() { 
        f[0][0]=1;  
        for(int i=1;i<=n;++i) {
            for(int j=1;j<=i;++j)
                f[i][j]=ADD(f[i-1][j-1],MUL(f[i-1][j],mod-p[i].x));    
            f[i][0]=(ll)f[i-1][0]*(mod-p[i].x)%mod;  
        }
    }
    int main() { 
        // setIO("input");  
        scanf("%d",&n); 
        int x,y,z;
        for(int i=1;i<n;++i) {           
            scanf("%d%d",&A[i],&B[i]); 
        }
        for(int i=1;i<=n;++i) { 
            p[i].x=i;  
            p[i].y=cal(i); 
        }    
        init(); 
        for(int i=1;i<=n;++i) { 
            int inv=1; 
            for(int j=1;j<=n;++j) {    
                if(i!=j) inv=(ll)inv*(p[i].x-p[j].x+mod)%mod;  
            }
            inv=get_inv(inv);      
            for(int j=0;j<=n;++j) {
                k2[j]=f[n][j];
            }
            for(int j=n-1;j>=0;--j) {
                k1[j]=k2[j+1];    
                k2[j]=ADD(k2[j],MUL(p[i].x,k1[j]));  
            }                
            for(int j=0;j<=n-1;++j) {
                ans[j]=ADD(ans[j],(ll)k1[j]*inv%mod*p[i].y%mod);  
            }
        }    
        for(int i=0;i<n;++i) {
            printf("%d ",ans[i]);  
        }
        return 0; 
    }
    

      

     LuoguP4463 [集训队互测2012] calc

    朴素的 DP 非常好列:$f[i][j]$ 表示选了 $i$ 个数,且值域为 $[1,j]$ 的总价值和.    

    那么有 $f[i][j]=f[i-1][j-1] imes j+f[i][j-1]$,直接算的话复杂度是 $O(nD)$ 的.   

    但是我们可以猜测这是一个关于 $j$ 的 $g_{i}$ 次多项式.    

    有一个结论:对于 $n$ 次多项式 $h(x)$,满足 $h(x)-h(x-1)$ 是 $n-1$ 次多项式.   

    那么有 $f[i][j]-f[i][j-1]=f[i-1][j-1] imes j$.    

    将 $g$ 带入,有 $g_{i}-1=g_{i-1}+1$.    

    即 $g_{i}=g_{i-1}+2$,说明这是一个关于 $j$ 的 $2 imes i$ 次多项式.    

    那么我们就求出 $f[n][1...2n+1]$ 后将值带入,然后拉格朗日插值来插一下就行了.   

    code: 

    #include <cstdio>  
    #include <cstring>
    #include <algorithm> 
    #define N 2002
    #define ll long long 
    #define setIO(s) freopen(s".in","r",stdin)
    using namespace std; 
    int D,n,mod,tot,f[N][N],fac[N];  
    void init() {
        fac[0]=1;  
        for(int i=1;i<N;++i) {
            fac[i]=(ll)fac[i-1]*i%mod;   
        }
    }
    struct point {
        int x,y;  
        point(int x=0,int y=0):x(x),y(y){}  
    }a[N];  
    int qpow(int x,int y) {
        int tmp=1; 
        for(;y;y>>=1,x=(ll)x*x%mod)  
            if(y&1) tmp=(ll)tmp*x%mod; 
        return tmp; 
    }  
    int get_inv(int x) {
        return qpow(x,mod-2);   
    }
    int calc() {
        int ans=0;   
        for(int i=1;i<=tot;++i) {  
            int inv=1,up=1;    
            for(int j=1;j<=tot;++j) {
                if(i==j) continue;     
                up=(ll)up*(D-a[j].x+mod)%mod;    
                inv=(ll)inv*(a[i].x-a[j].x+mod)%mod;   
            }
            inv=get_inv(inv);   
            (ans+=(ll)a[i].y*up%mod*inv%mod)%=mod;  
        }
        return ans;   
    }
    int main() {          
        // setIO("input");    
        scanf("%d%d%d",&D,&n,&mod);  
        init(); 
        for(int i=0;i<=2*n+1;++i) f[0][i]=1;   
        for(int i=1;i<=n;++i) {
            for(int j=1;j<=2*n+1;++j) {
                f[i][j]=(ll)(f[i][j-1]+(ll)f[i-1][j-1]*j%mod)%mod;  
            }
        } 
        for(int i=1;i<=2*n+1;++i) {
            a[++tot]=point(i,f[n][i]);    
        }
        printf("%d
    ",(ll)calc()*fac[n]%mod);   
        return 0;   
    }
    

      

  • 相关阅读:
    [Leetcode]@python 89. Gray Code
    [Leetcode]@python 88. Merge Sorted Array.py
    [Leetcode]@python 87. Scramble String.py
    [Leetcode]@python 86. Partition List.py
    [leetcode]@python 85. Maximal Rectangle
    0523BOM
    0522作业星座
    0522dom
    0520
    0519作业
  • 原文地址:https://www.cnblogs.com/guangheli/p/13329921.html
Copyright © 2011-2022 走看看