一类经典问题的解法
前言
在做这道题时,知道了有这样一种神奇的解法,然后就学了一发。
问题
给定(t(1 leq t leq 1e5)),(n(1 leq n leq 1e5)),计算所有的(sum_{i=1}^{n} a_i^{k}(1 leq k leq t))。
解析
先考虑计算一个式子
[f(x)=prod_{i=1}^{n}(1+a_{i}x)
]
这个式子可以用分治(O(n log^{2}{n}))求。
现在我们对这个式子取自然对数,那么有:
[ln(f(x))=sum_{i=1}^{n}ln(1+a_{i}x)
]
再求导,那么有:
[ln'(f(x))=sum_{i=1}^{n}ln'(1+a_{i}x)=sum_{i=1}^{n}frac{a_{i}}{1+a_{i}x}
]
观察发现求和式里面的式子其实是无穷递缩等比数列的求和公式,有:
[frac{a_{i}}{1+a_{i}x}=sum_{j=0}^{infty} (-1)^{j}a_{i}^{j+1}x^j
]
那么上式有:
[上式=sum_{i=1}^{n} sum_{j=0}^{infty} (-1)^{j}a_{i}^{j+1}x^j=sum_{j=1}^{infty} (-1)^{j} igl[sum_{i=1}^{n} a_{i}^{j+1}igr] x^j
]
现在这个式子的奇系数再取反,就得到了我们想要的式子。
总结一下,就是先NTT分治算出(f(x)=prod_{i=1}^{n}(1+a_{i}x)),再取对数求导,最后奇系数取反,复杂度为(O(nlog^{2}{n}))。
应用
回到上面的问题,我们要求的就是:
[sum_{i=1}^{n} sum_{j=1}^{m} (a_{i}+b_{j})^{k}$$,当然最后要乘上$nm$的逆元。
考虑二项式展开,有:
$$sum_{i=1}^{n} sum_{j=1}^{m} sum_{t=0}^{k} {k choose t} a_{i}^{t} b_{j}^{k-t}]
交换枚举顺序,有:
[sum_{t=0}^{k} {k choose t} sum_{i=1}^{n} sum_{j=1}^{m} a_{i}^{t} b_{j}^{k-t}=sum_{t=0}^{k} {k choose t} igl [sum_{i=1}^{n} a_{i}^t igr] igl[sum_{j=1}^{m} b_{j}^{k-t} igr]=k! imes sum_{t=0}^{k} frac{sum_{i=1}^{n} a_{i}^t}{t!}frac{sum_{j=1}^{m} b_{j}^{k-t}}{(k-t)!}
]
你会发现这就是两个上面所述的多项式乘上相应的系数后的卷积,复杂度(O(n log^{2}{n}))。
代码
#include<bits/stdc++.h>
#define N 300005
#define mid ((l+r)>>1)
using namespace std;
template<typename T> inline void In(T& x){
char c=getchar(); int ft=1;
for(x=0;c<'0'||c>'9';c=getchar()) if(c=='-') ft=-1;
for(;c>='0'&&c<='9';c=getchar()) x=x*10+c-'0';
x*=ft;
}
template<typename T,typename... U>
inline void In(T& x,U& ... y){
In(x); In(y...);
}
const int P=998244353,g=3;
int power(int x,int k){
if(!x) return 0;
int s=1,t=x;
for(;k;k>>=1,t=1ll*t*t%P)
if(k&1) s=1ll*s*t%P;
return s;
}
int n,m,K,L,C,inv_nm;
int ali[N/3],bob[N/3],fac[N/3],inv[N/3],r[N];
vector<int> a,b;
inline void Init(){
fac[0]=1; for(int i=1;i<=K;++i) fac[i]=1ll*i*fac[i-1]%P;
inv[K]=power(fac[K],P-2); for(int i=K-1;~i;--i) inv[i]=1ll*(i+1)*inv[i+1]%P;
}
inline void Output(vector<int> A){
int len=A.size();
for(int i=0;i<len;++i) printf("%d%c",A[i],i==len-1?'
':' ');
}
inline void NTT_prepare(int len){
L=1; C=0; while(L<=len) L<<=1,++C;
for(int i=1;i<L;++i) r[i]=(r[i>>1]>>1)|((i&1)<<(C-1));
}
void NTT(vector<int>& A,int op){
for(int i=0;i<L;++i) if(i<r[i]) swap(A[i],A[r[i]]);
for(int i=1;i<L;i<<=1){
int Wn=power(g,(P-1)/(i<<1));
if(op==-1) Wn=power(Wn,P-2);
for(int j=0;j<L;j+=(i<<1)){
int w=1;
for(int k=0;k<i;++k,w=1ll*w*Wn%P){
int p=A[j+k],q=1ll*w*A[i+j+k]%P;
A[j+k]=(p+q)%P; A[i+j+k]=((p-q)%P+P)%P;
}
}
}
if(op==-1){
int inv_L=power(L,P-2);
for(int i=0;i<L;++i)
A[i]=1ll*A[i]*inv_L%P;
}
}
vector<int> Mul(vector<int> A,vector<int> B) {
int len=A.size()+B.size()-1; NTT_prepare(len);
A.resize(L); NTT(A,1); B.resize(L); NTT(B,1);
for(int i=0;i<L;++i) A[i]=1ll*A[i]*B[i]%P;
NTT(A,-1); A.resize(len); return A;
}
vector<int> Inv(vector<int> A){
int len=A.size();
if(len==1) return vector<int>(1,power(A[0],P-2));
vector<int> B=Inv(vector<int>(A.begin(),A.begin()+(len+1)/2)),C=A;
NTT_prepare(len<<1); B.resize(L); C.resize(L); NTT(B,1); NTT(C,1);
for(int i=0;i<L;++i) B[i]=1ll*(2-1ll*C[i]*B[i]%P+P)%P*B[i]%P;
NTT(B,-1); B.resize(len); return B;
}
vector<int> Der(vector<int> A){
int len=A.size(); vector<int> B(len-1);
for(int i=0;i<len-1;++i) B[i]=1ll*(i+1)*A[i+1]%P;
return B;
}
vector<int> Int(vector<int> A){
int len=A.size(); vector<int> B(len+1);
for(int i=1;i<=len;++i) B[i]=1ll*power(i,P-2)*A[i-1]%P;
return B;
}
vector<int> ln(vector<int> A){
int len=A.size();
A=Int(Mul(Der(A),Inv(A)));
return A;
}
vector<int> Solve(int l,int r,int* A){
if(l==r){ vector<int> B(2); B[0]=1; B[1]=A[l]; return B; }
return Mul(Solve(l,mid,A),Solve(mid+1,r,A));
}
inline vector<int> Wan(vector<int> A){
int len=A.size();
for(int i=1;i<A.size();i+=2) A[i]=(P-A[i])%P;
return A;
}
int main(){
In(n,m);
for(int i=1;i<=n;++i) In(ali[i]);
for(int i=1;i<=m;++i) In(bob[i]);
In(K); Init();
a=Solve(1,n,ali); a.resize(K+1); a=Der(ln(a));
a=Wan(a); a.resize(K+1); for(int i=K;i;--i) a[i]=a[i-1]; a[0]=n;
for(int i=0;i<=K;++i) a[i]=1ll*inv[i]*a[i]%P;
b=Solve(1,m,bob); b.resize(K+1); b=Der(ln(b));
b=Wan(b); b.resize(K+1); for(int i=K;i;--i) b[i]=b[i-1]; b[0]=m;
for(int i=0;i<=K;++i) b[i]=1ll*inv[i]*b[i]%P;
a=Mul(a,b); a.resize(K+1); inv_nm=power(1ll*n*m%P,P-2);
for(int i=0;i<=K;++i) a[i]=1ll*fac[i]*a[i]%P*inv_nm%P;
for(int i=1;i<=K;++i) printf("%d
",a[i]);
return 0;
}