[BZOJ3684]大朋友和多叉树(拉格朗日反演)
题面
给定整数(n)和集合(S(1 otin S)),求有(n)个节点且每个非叶子节点的儿子数量(in S)的无标号有根树的数量。节点的孩子有顺序.(n,|S|leq 10^5)
分析
设这些树的OGF为(T(x)),根据定义,一棵树可以是单个叶子节点,或者是非叶子节点拼上(i)个子树组成的序列。
[T(x)=x+sum_{i in S}T^i(x)
]
构造函数(g(w)=w-sum_{i in S}w^i),容易发现(g(T(x))=x)。用拉格朗日反演求出(T)的(n)次项系数。
(frac{1}{n}[w^{n-1}](frac{w}{g(w)})^n=frac{1}{n}[w^{n-1}]((frac{g(w)}{w})^n)^{-1})
除以(w)相当于系数向左平移,然后多项式快速幂再求逆即可。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#define maxn 400000
#define mod 950009857
using namespace std;
typedef long long ll;
inline ll fast_pow(ll x,ll k) {
ll ans=1;
while(k) {
if(k&1) ans=ans*x%mod;
x=x*x%mod;
k>>=1;
}
return ans;
}
inline ll inv(ll x) {
return fast_pow(x,mod-2);
}
const ll G=5,invG=inv(G);
int rev[maxn+5];
void NTT(ll *x,int n,int type) {
for(int i=0; i<n; i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int len=1; len<n; len*=2) {
int sz=len*2;
ll gn1=fast_pow((type==1?G:invG),(mod-1)/sz);
for(int l=0; l<n; l+=sz) {
int r=l+len-1;
ll gnk=1;
for(int i=l; i<=r; i++) {
ll tmp=x[i+len];
x[i+len]=(x[i]-gnk*tmp%mod+mod)%mod;
x[i]=(x[i]+gnk*tmp%mod)%mod;
gnk=gnk*gn1%mod;
}
}
}
if(type==-1) {
ll invn=inv(n);
for(int i=0; i<n; i++) x[i]=x[i]*invn%mod;
}
}
void poly_mul(ll *a,ll *b,ll *c,int n,int m) {
static ll ta[maxn+5],tb[maxn+5];
int N=1,L=0;
while(N<n+m-1) {
N*=2;
L++;
}
for(int i=0; i<n; i++) ta[i]=a[i];
for(int i=n; i<N; i++) ta[i]=0;
for(int i=0; i<m; i++) tb[i]=b[i];
for(int i=n; i<N; i++) tb[i]=0;
for(int i=0; i<N; i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
NTT(ta,N,1);
NTT(tb,N,1);
for(int i=0; i<N; i++) c[i]=ta[i]*tb[i]%mod;
NTT(c,N,-1);
for(int i=n+m-1; i<N; i++) c[i]=0;
}
void poly_inv(ll *f,ll *g,int n) {
static ll tmp[maxn+5];
if(n==1){
g[0]=inv(f[0]);
return;
}
poly_inv(f,g,(n+1)/2);
poly_mul(f,g,tmp,n,n);
poly_mul(tmp,g,tmp,n,n);
for(int i=0;i<n;i++) g[i]=(2*g[i]-tmp[i]+mod)%mod;
}
void poly_deriv(ll *f,ll *g,int n){
for(int i=1;i<n;i++) g[i-1]=f[i]*i%mod;
g[n-1]=0;
}
void poly_inter(ll *f,ll *g,int n){
for(int i=n-1;i>=1;i--) g[i]=f[i-1]*inv(i)%mod;
g[0]=0;
}
void poly_ln(ll *f,ll *g,int n){
static ll invf[maxn+5];
poly_deriv(f,g,n);
poly_inv(f,invf,n);
poly_mul(invf,g,invf,n,n);
poly_inter(invf,g,n*2);
for(int i=n;i<n*2;i++) g[i]=0;
}
void poly_exp(ll *f,ll *g,int n){
static ll lng[maxn+5];
if(n==1){
g[0]=1;
return;
}
poly_exp(f,g,(n+1)/2);
poly_ln(g,lng,n);
for(int i=0;i<n;i++) lng[i]=(f[i]-lng[i]+mod)%mod;
lng[0]++;
poly_mul(g,lng,g,n,n);
for(int i=n;i<n*2;i++) g[i]=0;
}
void poly_pow(ll *f,ll *g,ll k,int n){
static ll tmpf[maxn+5];
for(int i=0;i<n;i++) g[i]=f[i];
poly_ln(g,tmpf,n);
for(int i=0;i<n;i++) tmpf[i]=tmpf[i]*k%mod;;
poly_exp(tmpf,g,n);
}
ll f[maxn+5],g[maxn+5];
int n,m;
int main(){
static ll tmp[maxn+5];
int x;
scanf("%d %d",&n,&m);
for(int i=1;i<=m;i++){//(x/g(x))^n=1/(g(x)/x)^n ������������һ�γ˷�
scanf("%d",&x);
f[x]=mod-1;
}
f[1]=1;
for(int i=0;i<n;i++) f[i]=f[i+1];//��x
// f[n-1]=0;
poly_pow(f,tmp,n,n);
poly_inv(tmp,g,n);
ll ans=g[n-1]*inv(n)%mod;
printf("%lld
",ans+1);
}