zoukankan      html  css  js  c++  java
  • 多项式快速插值

    多项式快速插值

    https://www.luogu.com.cn/problem/P5158

    给出 (n+1) 个点 ((x_0,y_0),(x_1,y_1),cdots,(x_n,y_n)) .求一个 (n) 次多项式 (f(x)) ,满足 (f(x_i) equiv y_i mod 998244353)

    (1 le n le 100000)

    (0 le x_i,y_i < 998244353) , (x_i) 互不相同

    Tutorial

    https://www.luogu.com.cn/blog/Minamoto/solution-p5158

    根据拉格朗日插值公式计算

    [egin{align} f(x) &= sum_{i=0}^n y_i prod_{j ot= i} dfrac {x-x_j}{x_i-x_j} \ &= sum_{i = 0}^n dfrac {y_i}{prod_{j ot= i}x_i - x_j} prod_{j ot= i} (x-x_j) end{align} ]

    考虑前面分母 (prod_{j ot= i} (x_i-x_j)) 部分 ,若设 (g(x) = prod_{j=0}^n (x-x_j),h(x) = x-x_i) ,那么这个值就是 (dfrac {g(x_i)}{h(x_i)}) ,发现此时分子分母均为 (0)

    洛必达法则

    如果

    [lim_{x o a} f(x) = 0, lim_{x o a} g(x) = 0 ]

    那么

    [lim_{x o a} dfrac {f(x)}{g(x)} = lim_{x o a} dfrac {f'(x)}{g'(x)} ]

    则根据洛必达法则,我们要计算的就是 (g'(x_i)) .

    那么可以用分治FFT算出 (g) ,然后用多项式多点求值算出每个 (g'(x_i)) ,这一部分的复杂度为 (O(n log^2n)) .

    之后对于这个式子分治FFT计算即可,设 (mid = lfloor dfrac n2 floor, a_i = dfrac {y_i}{g'(x_i)}) ,则

    [egin{align} f(x) &= sum_{i=0}^n a_i prod_{j ot= i} (x-x_j) \ &= prod_{j=mid+1}^{n} (x-x_j) sum_{i=0}^{mid} a_i prod_{j in [0,mid],j ot =i} (x-x_j) +prod_{j=0}^{mid} (x-x_j) sum_{i=mid+1}^n a_i prod_{j in [mid+1,n],j ot= i} (x-x_j) end{align} ]

    总复杂度 (O(n log^2 n))

    Code

    多点求值的时候,区间长度较小时直接暴力求值.

    #include <algorithm>
    #include <cstdio>
    #include <iostream>
    #include <vector>
    #define debug(...) fprintf(stderr,__VA_ARGS__)
    #define inver(a) power(a,mod-2)
    #define lson u<<1,l,mid
    #define rson u<<1|1,mid+1,r
    using namespace std;
    typedef long long ll;
    const int mod=998244353;
    const int maxn=1e5+50;
    const int maxnode=maxn<<2;
    int n;
    int a[maxn];
    int x[maxn],y[maxn];
    vector<int> P[maxnode];
    inline int add(int x) {return x>=mod?x-mod:x;}
    inline int sub(int x) {return x<0?x+mod:x;}
    ll power(ll x,ll y)
    {
    	ll re=1;
    	while(y)
    	{
    		if(y&1) re=re*x%mod;
    		x=x*x%mod;
    		y>>=1;
    	}
    	return re;
    }
    namespace pol
    {
    	vector<int> w[2][25];
    	void init()
    	{
    		static const int g=3;
    		int r=inver(g);
    		for(int i=1,s=0;i<maxnode;i<<=1,++s)
    		{
    			ll w0=power(g,(mod-1)/(i<<1)); w[0][s].push_back(1);
    			ll w1=power(r,(mod-1)/(i<<1)); w[1][s].push_back(1);
    			for(int k=1;k<i;++k)
    			{
    				w[0][s].push_back(w[0][s][k-1]*w0%mod);
    				w[1][s].push_back(w[1][s][k-1]*w1%mod);
    			}
    		}
    	}
    	void FFT(int *a,int n,int f)
    	{
    		int d=f==-1;
    		for(int i=0,j=0;i<n;++i)
    		{
    			if(i<j) swap(a[i],a[j]);
    			for(int l=n>>1;(j^=l)<l;l>>=1);
    		}
    		for(int i=1,s=0;i<n;i<<=1,++s)
    		{
    			for(int j=0,p=i<<1;j<n;j+=p)
    			{
    				int *u=a+j;
    				int *v=a+j+i;
    				for(int k=0;k<i;++k,++u,++v)
    				{
    					int x=*u;
    					int y=(ll)*v*w[d][s][k]%mod;
    					*u=add(x+y);
    					*v=sub(x-y);
    				}
    			}
    		}
    		if(f==-1)
    		{
    			ll r=inver(n);
    			for(int i=0;i<n;++i) a[i]=a[i]*r%mod;
    		}
    	}
    	void convenx(vector<int> &A,vector<int> &B,vector<int> &C,int degC)
    	{
    		int a[maxnode],b[maxnode];
    		int degA=A.size()-1,degB=B.size()-1;
    		int n=1; while(n<=degA+degB) n<<=1;
    		copy(A.begin(),A.end(),a),fill(a+degA+1,a+n,0);
    		copy(B.begin(),B.end(),b),fill(b+degB+1,b+n,0);
    		FFT(a,n,1),FFT(b,n,1);
    		for(int i=0;i<n;++i) a[i]=(ll)a[i]*b[i]%mod;
    		FFT(a,n,-1); 
    		C.resize(degC+1);
    		for(int i=0;i<=degC;++i) C[i]=a[i];
    	}
    	void inverse(vector<int> &A,int n,vector<int> &B)
    	{
    		static int a[maxnode],b[maxnode];
    		if(n==1)
    		{
    			B[0]=inver(A[0]);
    			return;
    		}
    		int mid=(n+1)>>1; 
    		inverse(A,mid,B);
    		copy(A.begin(),A.begin()+n,a);
    		copy(B.begin(),B.begin()+mid,b),fill(b+mid,b+n,0);
    		int deg=1; while(deg<=(n<<1)) deg<<=1;
    		fill(a+n,a+deg,0);
    		fill(b+n,b+deg,0);
    		FFT(a,deg,1),FFT(b,deg,1);
    		for(int i=0;i<deg;++i)
    			a[i]=(ll)sub(2-(ll)a[i]*b[i]%mod)*b[i]%mod;
    		FFT(a,deg,-1);
    		for(int i=0;i<n;++i) B[i]=a[i];
    	}
    	void module(vector<int> &A,vector<int> &B,vector<int> &R)
    	{
    		int n=A.size()-1,m=B.size()-1; if(n<m) {R=A; return;}
    		vector<int> A0=B; reverse(A0.begin(),A0.end()),A0.resize(n-m+1);
    		vector<int> B0; B0.resize(n-m+1); inverse(A0,n-m+1,B0);
    		A0=A,reverse(A0.begin(),A0.end()),A0.resize(n-m+1);
    		vector<int> D; pol::convenx(A0,B0,D,n-m); reverse(D.begin(),D.end());
    		pol::convenx(B,D,R,m-1);
    		for(int i=0;i<m;++i) R[i]=sub(A[i]-R[i]);
    	}
    }
    void divide(int u,int l,int r)
    {
    	if(l==r)
    	{
    		P[u].push_back(sub(-x[l]));
    		P[u].push_back(1);
    		return;
    	}
    	int mid=(l+r)>>1;
    	divide(lson);
    	divide(rson);
    	pol::convenx(P[u<<1],P[u<<1|1],P[u],r-l+1);
    }
    void evaluation(int u,int l,int r,vector<int> &f)
    {
    	vector<int> A; pol::module(f,P[u],A);
    	if(r-l<=20)
    	{
    		for(int i=l;i<=r;++i)
    		{
    			for(int j=0,b=1;j<f.size();++j,b=(ll)b*x[i]%mod) 
    				a[i]=(a[i]+(ll)f[j]*b)%mod;
    		}
    		return;
    	}
    	int mid=(l+r)>>1;
    	evaluation(lson,A);
    	evaluation(rson,A);
    }
    vector<int> interpolation(int u,int l,int r)
    {
    	if(l==r)
    	{
    		vector<int> re;
    		re.push_back(a[l]);
    		return re;
    	}
    	int mid=(l+r)>>1;
    	vector<int> L=interpolation(lson);
    	vector<int> R=interpolation(rson);
    	vector<int> re; pol::convenx(L,P[u<<1|1],re,r-l);
    	vector<int> t; pol::convenx(R,P[u<<1],t,r-l);
    	for(int i=0;i<re.size();++i) re[i]=add(re[i]+t[i]);
    	return re;
    }
    void sol()
    {
    	divide(1,0,n);
    	vector<int> g; g.resize(n+1);
    	for(int i=0;i<=n;++i) g[i]=(ll)P[1][i+1]*(i+1)%mod;
    	evaluation(1,0,n,g);
    	for(int i=0;i<=n;++i) a[i]=y[i]*inver(a[i])%mod;
    	vector<int> f=interpolation(1,0,n);
    	for(int i=0;i<=n;++i)
    	{
    		if(i) printf(" ");
    		printf("%d",f[i]);
    	}
    	printf("
    ");
    }
    int main()
    {
    	pol::init();
    	scanf("%d",&n),--n;
    	for(int i=0;i<=n;++i) scanf("%d%d",&x[i],&y[i]);
    	sol();
    	return 0;
    }
    
  • 相关阅读:
    react router实现多级嵌套路由默认跳转
    【转载】git 撤销,放弃本地修改
    js中RGB值与16进制颜色值进行互转
    【转载】whistle 使用实践
    程序员腰突经历分享(中)
    在非洲运营互联网系统-如何搞定支付?
    30岁后遇不治之症(上)
    递归把path字符串构造成递归数组
    使用go开发公众号之 关注公众号发送小程序卡片
    excel 函数经验答题
  • 原文地址:https://www.cnblogs.com/ljzalc1022/p/12917885.html
Copyright © 2011-2022 走看看