zoukankan      html  css  js  c++  java
  • 「NOTE」进阶多项式小札

    好像又很久没写博客了
    总结成一个小札还是香啊 awa
    把 @tly 的博客学了一遍然后翻译成自己看得懂的备忘

    # 目录


    # 基础多项式版子

    你可能会在下面看到 BASIC_POLY 这么一个命名空间,代码如下。

    点击展开/折叠代码
    #define con(type) const type &
    const int N=1<<20,MOD=998244353;
    inline int add(con(int)a,con(int)b){return a+b>=MOD? a+b-MOD:a+b;}
    inline int sub(con(int)a,con(int)b){return a-b<0? a-b+MOD:a-b;}
    inline int mul(con(int)a,con(int)b){return int(1ll*a*b%MOD);}
    inline int ina_pow(con(int)a,con(int)b){return b?mul(ina_pow(mul(a,a),b>>1),(b&1)?a:1):1;}
    
    namespace BASIC_POLY{
    	int w[N+10],eta_lg[N+10],rev[N+10];
    	void init(){
    		w[0]=eta_lg[1]=0;w[1]=ina_pow(3,(MOD-1)>>20);
    		for(int i=2;i<N;i++){
    			w[i]=mul(w[i-1],w[1]);
    			eta_lg[i]=eta_lg[(i+1)>>1]+1;
    		}
    	}
    	void ntt(int *ary,con(int)n,con(int)typ){
    		for(int i=1;i<n;i++){
    			rev[i]=rev[i>>1]>>1|((i&1)*(n>>1));
    			if(rev[i]<i) swap(ary[i],ary[rev[i]]);
    		}
    		for(int i=1,ii=2;i<n;i<<=1,ii<<=1){
    			int u=N/ii;
    			for(int j=0;j<n;j+=ii){
    				int *a=ary+j,*b=ary+j+i,*p=w,q=*b;
    				for(int k=0;k<i;k++,a++,b++,p+=u,q=mul(*b,*p))
    					*b=sub(*a,q),*a=add(*a,q);
    			}
    		}
    		if(typ==-1){
    			reverse(ary+1,ary+n);
    			for(int i=0,ivn=MOD-(MOD-1)/n;i<n;i++)
    				ary[i]=mul(ary[i],ivn);
    		}
    	}
    	int modelize(con(int)l){return 1<<eta_lg[l];}
    	void polyInv(int *a,int *b,con(int)n){
    		if(n==1){b[0]=ina_pow(a[0],MOD-2);return;}
    		int m=(n+1)>>1,l=modelize(2*(m-1)+(n-1)+1);
    		polyInv(a,b,m);
    		static int tmp1[N+10],tmp2[N+10];
    		for(int i=0;i<n;i++) tmp1[i]=a[i];fill(tmp1+n,tmp1+l,0);
    		for(int i=0;i<m;i++) tmp2[i]=b[i];fill(tmp2+m,tmp2+l,0);
    		ntt(tmp1,l,1),ntt(tmp2,l,1);
    		for(int i=0;i<l;i++) tmp1[i]=mul(tmp2[i],sub(2,mul(tmp1[i],tmp2[i])));
    		ntt(tmp1,l,-1);
    		for(int i=0;i<n;i++) b[i]=tmp1[i];
    	}
    	void polyDet(int *a,int *b,con(int)n){
    		for(int i=1;i<n;i++) b[i-1]=mul(a[i],i);
    		b[n-1]=0;
    	}
    }
    

    # 多项式多点求值

    - 问题1

    给定一个最高次项为 (n-1) 的多项式 (f(x)),求 (f(a_0),f(a_1)cdots f(a_{m-1}))

    - 新的解法

    之前的做法要用到多项式取模这种大常数计算,而且代码还很复杂……

    先定义一种所谓的“差卷积”,记作

    [fotimes g=sum_{i=0}sum_{j=0}^if_ig_jx^{i-j} ]

    我们可以构造出 (g(x)=frac{1}{1-ax}) 使得 (h=fotimes g) 满足 ([x^0]h(x)=f(a))

    点击展开/折叠证明

    把 $g(x)$ 的闭形式展开,则 $(fotimes g)$ 相当于

    $$(f_0+f_1x+f_2x^2+cdots)otimes(1+ax+a^2x^2+cdots)$$

    观察其常数项,一定是 $f_ix^iotimes a^ix^i$,也就是 $sum a^if_i$。

    这样我们可以把一个点的求值问题转化成卷积问题。那怎么扩展到多点求值呢?

    关于差卷积,我们还需要一个性质——

    [(fotimes g)otimes h=fotimes(gcdot h) ]

    根据定义很容易理解,即 (f_ig_jh_kx^{i-j-k}=f_ix^i(g_jh_k)x^{-(j+k)})

    有了这个性质,就可以考虑用线段树的方式分治。我们希望能在线段树的第 (i) 个叶子处得到 ([x^0](fotimesfrac{1}{1-a_ix}))。那么线段树上区间 ([l,r]) 就维护

    [S_{l,r}=fotimesBig(prod_{i=l}^rfrac{1}{1-a_ix}Big) ]

    要从 ([l,r]) 递推到子区间 ([l,m]),我们需要计算:

    [S_{l,m}=S_{l,r}otimesBig(prod_{i=m+1}^r{1-a_ix}Big) ]

    代入即可证明。于是我们还需要对线段树的每个节点 ([l,r]) 先预处理出

    [T_{l,r}=prod_{i=l}^r1-a_ix ]

    (T_{l,r}) 的次数和节点的大小是一致的,因为线段树上每一层节点大小减半,可以直接 (O(nlog n)) 预处理。

    似乎现在就可以解决多点求值了?但还剩了一个非常棘手的问题——(S_{l,r}) 是一个形式幂级数,我们应该保留多少项才能计算出叶子处的常数项

    实际上我们只需要保留 (S_{l,r}) 的前 (r-l+1) 项,归纳证明如下:

    1. 归纳边界:叶子处只需要保留常数项;
    2. 考虑 ([l,r]) 的子节点 ([l,m])(S_{l,m}) 需要保留前 (m-l+1) 项(最高项为 (m-l) 次),而 (S_{l,m}) 的计算方法是:

      [S_{l,m}=S_{l,r}otimes T_{m+1,r} ]

      注意到 (T_{m+1,r}) 最高项是 (r-m-1) 次的,根据差卷积的计算方式,(S_{l,r}) 的最高项应为 (r-l) 次,即只需要保留前 (r-l+1) 项。

    这样的话复杂度就有保证了,每层项数减半,仍然可以 (O(nlog^2 n)) 解决。

    - 源代码1

    点击展开/折叠代码

    //BASIC_POLY 中主要是NTT这些东西
    namespace UPPER_POLY{
    	typedef vector<int> vint;
    	//普通卷积
    	vint polyMul(con(vint)a,con(vint)b){
    		static int tmp1[N+10],tmp2[N+10];
    		int na=(int)a.size(),nb=(int)b.size(),nc=na+nb-1;
    		vint c(nc);
    		int l=BASIC_POLY::modelize(nc);
    		for(int i=0;i<na;i++) tmp1[i]=a[i];fill(tmp1+na,tmp1+l,0);
    		for(int i=0;i<nb;i++) tmp2[i]=b[i];fill(tmp2+nb,tmp2+l,0);
    		BASIC_POLY::ntt(tmp1,l,1),BASIC_POLY::ntt(tmp2,l,1);
    		for(int i=0;i<l;i++) tmp1[i]=mul(tmp1[i],tmp2[i]);
    		BASIC_POLY::ntt(tmp1,l,-1);
    		for(int i=0;i<nc;i++) c[i]=tmp1[i];
    		return c;
    	}
    	//差卷积,卷积结果保留前 nc项
    	vint polySubMul(con(vint)a,con(vint)b,con(int)nc){
    		vint c=a;reverse(c.begin(),c.end());
    		c=polyMul(c,b);
    		c.resize(a.size()),reverse(c.begin(),c.end());
    		c.resize(nc);
    		return c;
    	}
    	vint seg[N];
    	#define idx(l,r) (((l)+(r))|((l)!=(r)))
    	//预处理出 T[l,r]
    	void build(int *a,con(int)le,con(int)ri){
    		if(le==ri){
    			seg[idx(le,ri)].clear();
    			seg[idx(le,ri)].push_back(1);
    			seg[idx(le,ri)].push_back(sub(0,a[le]));
    			return;
    		}
    		int mi=(le+ri)>>1;
    		build(a,le,mi),build(a,mi+1,ri);
    		seg[idx(le,ri)]=polyMul(seg[idx(le,mi)],seg[idx(mi+1,ri)]);
    	}
    	//p即 S[l,r]
    	void solve(int *res,con(int)le,con(int)ri,con(vint)p){
    		if(le==ri){res[le]=p[0];return;}
    		int mi=(le+ri)>>1;
    		solve(res,le,mi,polySubMul(p,seg[idx(mi+1,ri)],mi-le+1));
    		solve(res,mi+1,ri,polySubMul(p,seg[idx(le,mi)],ri-mi));
    	}
    	void multiVal(int *f,con(int)n,int *pos,con(int)m,int *r){
    		static int tmp1[N+10],tmp2[N+10];
    		build(pos,0,m-1);
    		int rt=idx(0,m-1);
    		for(int i=0,ii=min(n,(int)seg[rt].size());i<ii;i++)
    			tmp1[i]=seg[rt][i];
    		BASIC_POLY::polyInv(tmp1,tmp2,n);
    		//T[0,n-1] 求个逆再和 f差卷积得到 S[0,n-1]
    		vint v1(n),v2(n);
    		for(int i=0;i<n;i++) v1[i]=f[i],v2[i]=tmp2[i];
    		solve(r,0,m-1,polySubMul(v1,v2,m));
    	}
    }
    


    # 多项式快速插值

    - 问题2

    给定 (n)((x_i,y_i))(n-1) 次多项式 (f(x)) 满足 (forall i,f(x_i)=y_i),求 (f(x))

    - 拉格朗日插值

    由若干个 (n-1) 次多项式叠加得到 (f(x))。具体构造如下:

    [g_i(x)=y_iprod_{j eq i}frac{x-x_j}{x_i-x_j} ]

    这样构造的目的是 (g_i(x_j)) 当且仅当 (i=j)(g_i(x_j) eq0),于是可以将 (g_0(x),g_1(x)cdots g_{n-1}(x)) 直接相加得到 (f(x))

    - 快速插值

    考虑对拉格朗日插值进行优化(优化计算方式以加速)。观察 (f(x)) 的表达式:

    [f(x)=sum_{i=0}^{n-1}y_iprod_{j eq i}frac{x-x_j}{x_i-x_j}=sum_{i=0}^{n-1}y_ifrac{prod(x-x_j)}{prod(x_i-x_j)} ]

    先看看怎么计算分式的分母部分,该部分对于固定的 (i) 来说是一个常数,记为 (k_i)

    [k_i=prod_{j eq i}(x_i-x_j) ]

    根据极限的相关知识,不难证明下面这个式子是成立的:

    [k_i=lim_{x o x_i}frac{prodlimits_{j=0}^{n-1}(x_i-x_j)}{x-x_i} ]

    便于书写,记 (g(x)=prodlimits_{i=0}^{n-1}(x-x_i))。分式上下都趋近于 (0),符合洛必达法则的适用条件——由此可得 (k_i=g'(x_i))

    可以先用类似线段树的分治(本质是启发式合并)在 (O(nlog n)) 的复杂度内算出 (g(x)),然后 (O(n)) 多项式求导得到 (g'(x)),再多点求值就可以得到 (k_i) 了。

    再看 (f(x)),我们现在可以把它写成

    [f(x)=sum_{i=0}^{n-1}frac{y_i}{k_i}prod_{j eq i}(x-x_j) ]

    则需要解决后面这个式子。这其实可以用到多点求值的旧方法中的一个技巧——仍然是分治:

    分治

    记 $h_{l,r}(x)$:

    $$h_{l,r}(x)=sum_{i=l}^rfrac{y_i}{k_i}prod_{jin[l,r]}^{j eq i}(x-x_j)$$

    仍然用类似线段树的方法分治,考虑如何从 $h_{l,m}(x)$ 和 $h_{m+1,r}(x)$ 合并到 $h_{l,r}(x)$。

    $$ egin{aligned} S_{l,r}&=prod_{i=l}^r(x-x_i)\ h_{l,r}(x)&=h_{l,m}S_{m+1,r}+S_{l,m}h_{m+1,r} end{aligned} $$

    这样就可以 $O(nlog^2 n)$ 求出答案。

    总的复杂度就是 (O(n^2+nlog^2 n))

    - 源代码2

    点击展开/折叠代码
    namespace UPPER_POLY{
    	typedef vector<int> vint;
    	vint polyAdd(con(vint)a,con(vint)b){
    		int na,nb;
    		vint c(max(na=(int)a.size(),nb=(int)b.size()));
    		for(int i=0;i<na;i++) c[i]=a[i];
    		for(int i=0;i<nb;i++) c[i]=add(c[i],b[i]);
    		return c;
    	}
    	vint polyMul(con(vint)a,con(vint)b){
    		int na=(int)a.size(),nb=(int)b.size(),nc=na+nb-1,l=BASIC_POLY::modelize(nc);
    		static int tmp1[N+10],tmp2[N+10];
    		for(int i=0;i<na;i++) tmp1[i]=a[i];fill(tmp1+na,tmp1+l,0);
    		for(int i=0;i<nb;i++) tmp2[i]=b[i];fill(tmp2+nb,tmp2+l,0);
    		BASIC_POLY::ntt(tmp1,l,1),BASIC_POLY::ntt(tmp2,l,1);
    		for(int i=0;i<l;i++) tmp1[i]=mul(tmp1[i],tmp2[i]);
    		BASIC_POLY::ntt(tmp1,l,-1);
    		vint c(nc);
    		for(int i=0;i<nc;i++) c[i]=tmp1[i];
    		return c;
    	}
    	vint polySubMul(con(vint)a,con(vint)b,con(int)nc){
    		vint c=a;
    		reverse(c.begin(),c.end());
    		c=polyMul(c,b);
    		c.resize(a.size());
    		reverse(c.begin(),c.end());
    		c.resize(nc);
    		return c;
    	}
    	vint seg[N+10];
    	#define idx(l,r) (((l)+(r))|((l)!=(r)))
    	void build(int *a,con(int)le,con(int)ri){
    		if(le==ri){
    			int u=idx(le,ri);
    			seg[u].clear();
    			seg[u].push_back(1),seg[u].push_back(sub(0,a[le]));
    			return;
    		}
    		int mi=(le+ri)>>1;
    		build(a,le,mi),build(a,mi+1,ri);
    		seg[idx(le,ri)]=polyMul(seg[idx(le,mi)],seg[idx(mi+1,ri)]);
    	}
    	void solve1(int *res,con(int)le,con(int)ri,con(vint)p){
    		if(le==ri){res[le]=p[0];return;}
    		int mi=(le+ri)>>1;
    		solve1(res,le,mi,polySubMul(p,seg[idx(mi+1,ri)],mi-le+1));
    		solve1(res,mi+1,ri,polySubMul(p,seg[idx(le,mi)],ri-mi));
    	}
    	void multiVal(int *f,con(int)n,int *pos,con(int)m,int *res){
    		build(pos,0,m-1);
    		static int tmp1[N+10],tmp2[N+10];
    		int rt=idx(0,m-1);
    		for(int i=0,ii=(int)seg[rt].size();i<n;i++)
    			tmp1[i]=i<ii?seg[rt][i]:0;
    		BASIC_POLY::polyInv(tmp1,tmp2,n);
    		vint p1(n),p2(n);
    		for(int i=0;i<n;i++) p1[i]=f[i],p2[i]=tmp2[i];
    		solve1(res,0,m-1,polySubMul(p1,p2,m));
    	}
    	void build2(int *a,con(int)le,con(int)ri){
    		if(le==ri){
    			seg[idx(le,ri)].clear();
    			seg[idx(le,ri)].push_back(sub(0,a[le])),seg[idx(le,ri)].push_back(1);
    			return;
    		}
    		int mi=(le+ri)>>1;
    		build2(a,le,mi),build2(a,mi+1,ri);
    		seg[idx(le,ri)]=polyMul(seg[idx(le,mi)],seg[idx(mi+1,ri)]);
    	}
    	vint solve2(int *y,con(int)le,con(int)ri){
    		if(le==ri) return vint(1,y[le]);
    		int mi=(le+ri)>>1;
    		return polyAdd(polyMul(solve2(y,le,mi),seg[idx(mi+1,ri)]),polyMul(seg[idx(le,mi)],solve2(y,mi+1,ri)));
    	}
    	vint multiPoly(int *pos,int *val,int n){
    		static int tmp1[N+10],tmp2[N+10];
    		build2(pos,0,n-1);
    		for(int i=0,rt=idx(0,n-1);i<=n;i++) tmp1[i]=seg[rt][i];
    		BASIC_POLY::polyDet(tmp1,tmp1,n+1);
    		multiVal(tmp1,n,pos,n,tmp2);
    		build2(pos,0,n-1);
    		//getinv
    		tmp1[0]=tmp2[0];for(int i=1;i<n;i++) tmp1[i]=mul(tmp1[i-1],tmp2[i]);
    		tmp1[n-1]=ina_pow(tmp1[n-1],MOD-2);
    		for(int i=n-2;~i;i--){
    			int tmp1i=mul(tmp1[i+1],tmp2[i+1]);
    			tmp1[i+1]=mul(tmp1[i+1],tmp1[i]);
    			tmp1[i]=tmp1i;
    		}
    		for(int i=0;i<n;i++) tmp1[i]=mul(tmp1[i],val[i]);
    		return solve2(tmp1,0,n-1);
    	}
    }
    

    THE END

    Thank for reading!

    若陪伴是我所有
    那痕迹 藏匿了整个宇宙

    ——《听风捕梦》By (×28)双笙/封茗囧菌/司南

    > Link 听风捕梦-网易云

    欢迎转载٩(๑❛ᴗ❛๑)۶,请在转载文章末尾附上原博文网址~
  • 相关阅读:
    mac中导出CSV格式在excel中乱码
    Android Gradle与Gradle插件的对应关系
    【算法】二叉树的前序、中序、后序、层序遍历和还原。
    关于Java虚拟机
    从Java synchronized和volatile说起
    【程小白】Java基本特性
    Android一键锁屏APP
    Fragment的生命周期
    学习大数据必须了解的大数据开发课程大纲
    学习大数据这三个关键技术是一定要掌握!
  • 原文地址:https://www.cnblogs.com/LuckyGlass-blog/p/14417254.html
Copyright © 2011-2022 走看看