zoukankan      html  css  js  c++  java
  • 【LOJ】#2320. 「清华集训 2017」生成树计数

    题解

    我,理解题解,用了一天
    我,卡常数,又用了一天
    到了最后,我才发现,我有个加法取模,写的是while(c >= MOD) c -= MOD
    我把while改成if,时间,少了
    六倍。
    六倍。
    六倍!!!!

    maya我又用第一次T的代码改掉了while,我第一次T的代码也A了= =

    那我,改单位复根,FFT循环展开,分治内部循环展开,为了啥= =

    好吧,但是我最后上榜了。。。LOJ第四的样子。。

    (prod_{i = 1}^{N} d_{i}^{m}sum_{i = 1}^{N}d_{i}^{m})
    这个时候,我们把连通块当成一个大点做prufer序,序列里的每个连通块(i)的位置都有(a_{i})种填数方式,式子可以改写成这样
    (prod_{i = 1}^{N} a_{i}^{d_{i}}d_{i}^{m}sum_{i = 1}^{N}d_{i}^{m})
    (sum_{i = 1}^{N} d_{i}^{m}prod_{j = 1}^{N}a_{j}^{d_{j}}d_{j}^{m})

    我们考虑一下一个最暴力的dp(我的第一反应,啥,这怎么是dp?

    当然是dp出一个prufer序列啦

    (f[i][j])表示考虑了前i个数,填了j个格子,没有多乘任何一个(d_{i}^{m})
    (g[i][j])表示考虑了前i个数,填了j个格子,已经乘了一个(d_{i}^{m})

    转移的时候,枚举当前这个数填上k个格子

    (f[i][j] += f[i - 1][j - k] * a_{i}^{k} k^{m})
    (g[i][j] += g[i - 1][j - k] * a_{i}^{k} k^{m})
    (g[i][j] += f[i - 1][j - k] * a_{i}^{k} k^{2m})
    复杂度(O(n^{3}))

    但是显然过不掉

    我们再考虑一个……套路!因为m出奇的小?
    乘方转斯特林数
    想一下(x^{m})的组合意义,相当于(m)种颜色涂在(x)个格子里,每个种颜色只能涂一个格子,但是每个格子可以涂很多颜色

    再看一下原式子,把它变成这样的形式
    (sum_{i = 1}^{N} a_{i}^{d_{i}}d_{i}^{2m}prod_{j = 1,j != i}^{N}a_{j}^{d_{j}}d_{j}^{m})

    然后,我们把乘方拆开,怎么拆呢
    考虑到上述的组合意义
    我们也就是求所有prufer序列的染色方式,在每一次决策的时候选择是否将某个点用(2m)种颜色染,我们在序列后面(脑补)出1 - N,这样prufer序列里数字出现的个数就正好是点度

    (m)很小,染色的格子很少,我们从这个方向考虑
    设已经被染色的格子个数是(j),那么再放入一个新的数,染色个数为(k)的时候,需要乘上
    (inom{n - 2 - j}{k}((S(m,k)k! + S(m,k + 1)(k + 1)!))
    我们就是在(n - 2 - j)个没有染色的位置里,选择(k)个位置,然后(S(m,k))就相当于把这(m)种颜色分配到这(k)个位置里,然后这些颜色还可以再次打乱顺序进行排列,就再乘一个(k!)
    为什么还有(k + 1),因为(k)代表的是这(n - 2)个prufer序里面的这类点染色的个数,我们还有后面脑补出的1 - N的位置,所以要讨论+1的情况
    这个时候我们已经给每个连通块选好了(d_{i})个点,并涂好了颜色,然后我们再乘上式子里的(a_{i}^{d_{i}})就好
    如果是(2m)种颜色染色,把(m)换成(2m)就可以

    我们观察一下这个组合数,我们会发现我们可以把(frac{1}{k!})分离出来,每次统计贡献的时候加上,剩下的等全部算完答案后,如果染色的个数为(j),那么答案再乘上(frac{(n - 2)!}{(n - 2 - j)!})
    同时,我们除了染色的位置还有很多空位,每个空位都有(sum_{i = 1}^{n}a_{i})种情况,如果有(j)个位置被染色了,那么答案还要再乘上
    ((sum_{i = 1}^{n}a_{i})^{n - 2 - j})
    这样的复杂度是(n^2m)

    什么,你觉得想到这个有点别扭?
    那我写一个更套路一点的想法吧

    我们回忆到斯特林数的展开式,根据容斥,显然有
    (S(m,k) = frac{1}{k!} sum_{i = 0}^{k}(-1)^{i} inom{k}{i} (k - i)^{m})
    什么,你说你不知道这个式子……简单理解一下,就是我们枚举(i)个盒子里没有任何元素,(inom{k}{i})枚举了所有(i)个盒子的集合……就有了这个容斥
    我们稍稍的变一下形
    (S(m,k) k!= sum_{i = 0}^{k}(-1)^{k - i} inom{k}{i} (i)^{m})
    有没有注意到啥。。这个式子无比的眼熟
    如果还没注意到就设
    (f(k) = S(m,k) k!)
    (g(k) = k^{m})
    (f(k) = sum_{i = 0}^{k}(-1)^{k - i} inom{k}{i} g(i))
    如果你还不为所动……我就打人了= =

    好吧,这很明显是个二项式反演,我们尝试用(f(x))表示(g(x))

    (g(x) = sum_{i = 0}^{x}inom{x}{i}f(i))

    我们尝试把点度(d)代入

    (d^{m} = sum_{k = 0}^{d} inom{d}{k} S(m,k) k!)
    (d)可能很大,是没错,但是(m)很小,在这里,我们假设(m < d),因为对于(S(m,k))如果(k > m)那么(S(m,k) = 0)

    (d^{m} = sum_{k = 0}^{m} inom{d}{k} S(m,k) k!)
    那么这个方程是什么呢,代表了(d^{m})可以用在(d)位置里选(k)个位置然后乘上(S(m,k)k!)来组合出来

    嗯?和(O(n^3))的算法有点类似了?
    但是和那个暴力DP不同的是,我们只需要枚举(m)个位置了,复杂度降到了(O(n^{2}m))

    我们只需要考虑如何用dp覆盖对于每一种prufer序列我们对每个连通块的大点,每种点都选(0 - m)个位置,来乘上(S(m,k)k!)

    这个时候,我们想到我们记录染色的位置,来求出所有本质不同的染色序列,这就是我们的答案了

    会不会有人到这里就不理解(S(m,k + 1)(k + 1)!)了啊……,因为prufer序就是前面n - 2个数决定的,不关后面(脑补)的1 - N个数的顺序
    所以我们要单独讨论一下,可以认为是我们对于这个式子
    (d^{m} = sum_{k = 0}^{m} inom{d}{k} S(m,k) k!)
    假如现在k = 0,我们的(S(m,k + 1)(k + 1)!)是统计了k = 1时的某些情况
    假如现在k = 1,我们的(S(m,k + 1)(k + 1)!)是统计了k = 2时的某些情况

    (f,g)的定义和前面的暴力dp定义类似
    (f[i][j])表示考虑了前(i)个点,有(j)个位置被染色了
    (g[i][j])表示考虑了前(i)个点,有(j)个位置被染色,还同时统计了某一处位置用(2m)染色了

    给出一段暴力的代码吧……

    int cur = 0;
    f[0][0] = 1;
    for(int i = 1 ; i <= N ; ++i) {
    	memset(f[cur ^ 1],0,sizeof(f[cur ^ 1]));
    	memset(g[cur ^ 1],0,sizeof(g[cur ^ 1]));
    	for(int j = 0 ; j <= N - 2 ; ++j) {
    	    for(int k = 0 ; k <= M && j + k <= N - 2 ; ++k) {
    			update(f[cur ^ 1][j + k],mul(f[cur][j],F1[i][k]));
    			update(g[cur ^ 1][j + k],mul(g[cur][j],F1[i][k]));
    	    }
    	    for(int k = 0 ; k <= 2 * M  && j + k <= N - 2; ++k) {
    			update(g[cur ^ 1][j + k],mul(f[cur][j],F2[i][k]));
    	    }
    	}
    	cur ^= 1;
    }
    

    (F1,F2)分别代表是用一个(m)染色还是用2个(m)染色

    似乎没有什么可优化啦……
    然而这个方程本质是个卷积,可以上分治FFT,然后就在(O(nm log n))解决了。。。

    他说不卡常……确实不卡常,我的错误神奇得太离谱了= =

    正常模样的代码

    #include <iostream>
    #include <cstdio>
    #include <vector>
    #include <algorithm>
    #include <cmath>
    #include <cstring>
    #include <map>
    //#define ivorysi
    #define pb push_back
    #define space putchar(' ')
    #define enter putchar('
    ')
    #define mp make_pair
    #define pb push_back
    #define fi first
    #define se second
    #define mo 974711
    #define MAXN 30005
    using namespace std;
    typedef long long int64;
    typedef double db;
    template<class T>
    void read(T &res) {
        res = 0;char c = getchar();T f = 1;
        while(c < '0' || c > '9') {
    	if(c == '-') f = -1;
    	c = getchar();
        }
        while(c >= '0' && c <= '9') {
    	res = res * 10 + c - '0';
    	c = getchar();
        }
        res *= f;
    }
    template<class T>
    void out(T x) {
        if(x < 0) {putchar('-');x = -x;}
        if(x >= 10) {
    	out(x / 10);
        }
        putchar('0' + x % 10);
    }
    const int MOD = 998244353,G = 3,L = (1 << 20);
    int fac[MAXN],S[75][75],invfac[MAXN],N,M,a[MAXN],P[MAXN][75],F1[MAXN][75],F2[MAXN][75];
    int f[2][MAXN],g[2][MAXN],W[L + 5];
    int mul(int a,int b) {return 1LL * a * b % MOD;}
    int inc(int a,int b) {a = a + b;if(a >= MOD) a -= MOD;return a;}
    void update(int &x,int y) {x = inc(x,y);}
    int fpow(int x,int c) {
        int res = 1,t = x;
        while(c) {
    		if(c & 1) res = mul(res,t);
    		t = mul(t,t);
    		c >>= 1;
        }
        return res;
    }
    struct poly {
    	vector<int> a;
    	poly() {a.clear();}
    	friend void NTT(poly &f,int T,int on) {
    		f.a.resize(T);
    		for(int i = 1 , j = T / 2; i < T - 1; ++i) {
    			if(i < j) swap(f.a[i],f.a[j]);
    			int k = T / 2;
    			while(j >= k) {j -= k;k >>= 1;}
    			j += k;
    		}
    		for(int h = 2 ; h <= T ; h <<= 1) {
    			int wn = W[(L + on * L / h) % L];
    			for(int k = 0 ; k < T ; k += h) {
    				int w = 1;
    				for(int j = k ; j < k + h / 2 ; ++j) {
    					int u = f.a[j],t = mul(f.a[j + h / 2],w);
    					f.a[j] = inc(u,t);
    					f.a[j + h / 2] = inc(u,MOD - t);
    					w = mul(w,wn);
    				}
    			}
    		}
    		if(on == -1) {
    			int InvT = fpow(T,MOD - 2);
    			for(int i = 0 ; i < T ; ++i) f.a[i] = mul(f.a[i],InvT);
    		}
    	}
    	friend poly operator + (const poly &f,const poly &g) {
    		int T = max(f.a.size(),g.a.size());
    		poly h;h.a.resize(T);
    		for(int i = 0 ; i < T ; ++i) h.a[i] = inc(f.a[i],g.a[i]);
    		return h;
    	}
    	friend poly operator * (poly f,poly g) {
    		int T = 1,t = f.a.size() + g.a.size();
    		while(T <= t) T <<= 1;
    		NTT(f,T,1);NTT(g,T,1);
    		poly h;h.a.resize(T);
    		for(int i = 0 ; i < T ; ++i) h.a[i] = mul(f.a[i],g.a[i]);
    		NTT(h,T,-1);
    		for(int i = T - 1 ; i >= 0 ; --i) {
    			if(!h.a[i]) h.a.pop_back();
    			else break;
    		}
    		if(h.a.size() > N - 1) h.a.resize(N - 1);
    		return h;
    	}
    };
    void Init() {
        read(N);read(M);
        int T = max(N,2 * M + 1);
        fac[0] = 1;
        for(int i = 1 ; i <= T ; ++i) fac[i] = mul(fac[i - 1],i);
        invfac[T] = fpow(fac[T],MOD - 2);
        for(int i = T - 1; i >= 0 ; --i) invfac[i] = mul(invfac[i + 1],i + 1);
        S[0][0] = 1;
        for(int i = 1 ; i <= 70 ; ++i) {
    		for(int j = 1 ; j <= i ; ++j) {
    		    S[i][j] = inc(S[i - 1][j - 1],mul(S[i - 1][j],j));
    		}
        }
        
        for(int i = 1 ; i <= N ; ++i) read(a[i]);
        for(int i = 1 ; i <= N ; ++i) {
    		P[i][0] = 1;
    		for(int j = 1 ; j <= 70 ; ++j) {
    		    P[i][j] = mul(P[i][j - 1],a[i]);
    		}
        }
        W[0] = 1,W[1] = fpow(G,(MOD - 1) / L);
        for(int i = 2 ; i < L ; ++i) W[i] = mul(W[i - 1],W[1]);
    }
    pair<poly,poly> Solve(int L,int R) {
    	if(L == R) {
    		poly s,t;s.a.resize(M + 1);t.a.resize(2 * M + 1);
    		for(int i = 0 ; i <= 2 * M ; ++i) {
    			if(i <= M) 
    				s.a[i] = mul(mul(inc(mul(S[M][i],fac[i]),mul(S[M][i + 1],fac[i + 1])),invfac[i]),P[L][i]);
    			t.a[i] = mul(mul(inc(mul(S[2 * M][i],fac[i]),mul(S[2 * M][i + 1],fac[i + 1])),invfac[i]),P[L][i]);
    		}
    		return mp(s,t);
    	}
    	int mid = (L + R) >> 1;
    	pair<poly,poly> S = Solve(L,mid),T = Solve(mid + 1,R);
    	return mp(S.fi * T.fi,S.se * T.fi + S.fi * T.se);
    }
    int main() {
    #ifdef ivorysi
        freopen("f1.in","r",stdin);
    #endif
        Init();
        pair<poly,poly> res = Solve(1,N);
        poly g = res.se;g.a.resize(N - 1);
        int sum = 0,p = 1,ans = 0,t = 1;
        for(int i = 1 ; i <= N ; ++i) sum = inc(sum,a[i]),p = mul(p,a[i]);
    	for(int i = N - 2 ; i >= 0; --i) {
    		ans = inc(ans,mul(mul(g.a[i],t),invfac[N - 2 - i]));
    		t = mul(t,sum);
        }
        ans = mul(ans,mul(fac[N - 2],p));
        out(ans);putchar('
    ');
        return 0;
    }
    

    卡常之后的代码

    #include <iostream>
    #include <cstdio>
    #include <vector>
    #include <algorithm>
    #include <cmath>
    #include <cstring>
    #include <map>
    //#define ivorysi
    #define pb push_back
    #define space putchar(' ')
    #define enter putchar('
    ')
    #define mp make_pair
    #define pb push_back
    #define fi first
    #define se second
    #define mo 974711
    #define MAXN 400005
    #define RG register
    using namespace std;
    typedef long long int64;
    typedef double db;
    template<class T>
    void read(T &res) {
        res = 0;char c = getchar();T f = 1;
        while(c < '0' || c > '9') {
    	if(c == '-') f = -1;
    	c = getchar();
        }
        while(c >= '0' && c <= '9') {
    	res = res * 10 + c - '0';
    	c = getchar();
        }
        res *= f;
    }
    template<class T>
    void out(T x) {
        if(x < 0) {putchar('-');x = -x;}
        if(x >= 10) {
    	out(x / 10);
        }
        putchar('0' + x % 10);
    }
    const int MOD = 998244353,L = (1 << 16);
    int fac[MAXN],S[105][105],invfac[MAXN],a[MAXN],N,M,P[MAXN][105];
    int W[L + 5],IW[L + 5],top,fl[MAXN],fr[MAXN],gl[MAXN],gr[MAXN],f[MAXN],g[MAXN];
    vector<int> F[105],G[105];
    inline int mul(RG const int &a,RG const int &b) {return (int64)a * b % MOD;}
    inline int inc(RG const int &a,RG const int &b) {RG int c = a + b;if(c >= MOD) c -= MOD;return c;}
    int fpow(RG int x,RG int c) {
        RG int res = 1,t = x;
        while(c) {
    	if(c & 1) res = mul(res,t);
    	t = mul(t,t);
    	c >>= 1;
        }
        return res;
    }
    void NTT(RG int *f,RG int T,RG int on,RG int *w) {
        RG int tmp,*wm,*ai,*ami;
        for(RG int i = 1 , j = T / 2; i < T - 1 ; ++i) {
    	if(i < j) {tmp = f[i];f[i] = f[j];f[j] = tmp;}
    	RG int k = T / 2;
    	while(j >= k) {
    	    j -= k;
    	    k >>= 1;
    	}
    	j += k;
        }
        
        #define work(j) {tmp = mul(ami[j],wm[j]);ami[j] = inc(ai[j],MOD - tmp);ai[j] = inc(ai[j],tmp); }
        for(RG int h = 2,m = 1; h <= T ; m = h,h <<= 1) {
    	wm = w + m;
    	if(m < 8) {
    	    for(RG int i = 0 ; i < T ; i += h) {
    		ai = f + i,ami = f + m + i;
    		for(RG int j = 0 ; j < m ; ++j) work(j);
    	    }
    	}
    	else {
    	    for(RG int i = 0 ; i < T ; i += h) {
    		ai = f + i,ami = f + m + i;
    		for(RG int j = 0 ; j < m ; j += 8) {
    		    work(j);
    		    work(j + 1);
    		    work(j + 2);
    		    work(j + 3);
    		    work(j + 4);
    		    work(j + 5);
    		    work(j + 6);
    		    work(j + 7);
    		}
    	    }
    	}
        }
        if(on < 0) {
    	RG int InvT = fpow(T,MOD - 2);
            #define C(x,y) {f[x] = mul(f[x],y);}
    	if(T < 8) {for(RG int i = 0 ; i < T ; ++i) C(i,InvT);}
    	else {
    	    for(RG int i = 0 ; i < T ; i += 8) {
    		C(i,InvT);
    		C(i + 1,InvT);
    		C(i + 2,InvT);
    		C(i + 3,InvT);
    		C(i + 4,InvT);
    		C(i + 5,InvT);
    		C(i + 6,InvT);
    		C(i + 7,InvT);
    	    }
    	}
        }
    }
    void Init() {
        read(N);read(M);
        RG int T = max(N,2 * M + 1);
        fac[0] = 1;
        for(RG int i = 1 ; i <= T ; ++i) fac[i] = mul(fac[i - 1],i);
        invfac[T] = fpow(fac[T],MOD - 2);
        for(RG int i = T - 1; i >= 0 ; --i) invfac[i] = mul(invfac[i + 1],i + 1);
        S[0][0] = 1;
        for(RG int i = 1 ; i <= 70 ; ++i) {
    	for(RG int j = 1 ; j <= i ; ++j) {
    	    S[i][j] = inc(S[i - 1][j - 1],mul(S[i - 1][j],j));
    	}
        }
        for(RG int j = 1 ; j <= 2 * M ; ++j) {
        	S[M][j] = mul(S[M][j],fac[j]);
        	S[2 * M][j] = mul(S[2 * M][j],fac[j]); 
        }
        for(RG int j = 0 ; j <= 2 * M ; ++j) {
        	S[M][j] = mul(inc(S[M][j],S[M][j + 1]),invfac[j]);
        	S[2 * M][j] = mul(inc(S[2 * M][j],S[2 * M][j + 1]),invfac[j]);
        }
        for(RG int i = 1 ; i <= N ; ++i) read(a[i]);
        for(RG int i = 1 ; i <= N ; ++i) {
    	P[i][0] = 1;
    	for(RG int j = 1 ; j <= 70 ; ++j) {
    	    P[i][j] = mul(P[i][j - 1],a[i]);
    	}
        }
        RG int half = L / 2;
        RG int t1 = fpow(3,(MOD - 1) / L),t2 = fpow(t1,MOD - 2);
        W[half] = 1;IW[half] = 1;
        for(RG int i = 1 ; i < half; ++i) W[i + half] = mul(W[i + half - 1],t1),IW[i + half] = mul(IW[i + half - 1],t2);
        for(RG int i = half - 1 ; i >= 0 ; --i) W[i] = W[i << 1],IW[i] = IW[i << 1];
    }
    void Solve(RG int L,RG int R) {
        
        if(L == R) {
    	++top;
    	F[top].resize(M + 1);G[top].resize(2 * M + 1);
    	for(RG int i = 0 ; i <= 2 * M ; ++i) {
    	    if(i <= M) 
    		F[top][i] = mul(S[M][i],P[L][i]);
    	    G[top][i] = mul(S[2 * M][i],P[L][i]);
    	}
    	return ;
        }
        RG int mid = (L + R) >> 1;
        Solve(L,mid);int Ld = top; 
        Solve(mid + 1,R);int Rd = top;
        top -= 2;
        RG int s1 = F[Ld].size(),s2 = F[Rd].size(),s3 = G[Ld].size(),s4 = G[Rd].size();
        RG int t = max(max(s1 + s2,s1 + s4),s3 + s2);
        RG int K = 1;while(K <= t) K <<= 1;
        for(int i = 0 ; i < s1 ; ++i) fl[i] = F[Ld][i];
        for(int i = 0 ; i < s2 ; ++i) fr[i] = F[Rd][i];
        for(int i = 0 ; i < s3 ; ++i) gl[i] = G[Ld][i];
        for(int i = 0 ; i < s4 ; ++i) gr[i] = G[Rd][i];
        fill(fl + s1,fl + K,0);
        fill(fr + s2,fr + K,0);
        fill(gl + s3,gl + K,0);
        fill(gr + s4,gr + K,0);
        NTT(fl,K,1,W);NTT(fr,K,1,W);NTT(gl,K,1,W);NTT(gr,K,1,W);
    #define Calc1(i) {f[i] = mul(fl[i],fr[i]);}
    #define Calc2(i) {g[i] = inc(mul(fl[i],gr[i]),mul(gl[i],fr[i]));}
        if(K < 8) {
    	for(int i = 0 ; i < K ; ++i) {
    	    Calc1(i);Calc2(i);
    	}
        }
        else {
    	for(RG int i = 0 ; i < K ; i += 8) {
    	    Calc1(i);Calc1(i + 1);
    	    Calc1(i + 2);Calc1(i + 3);
    	    Calc1(i + 4);Calc1(i + 5);
    	    Calc1(i + 6);Calc1(i + 7);
    	    Calc2(i);Calc2(i + 1);
    	    Calc2(i + 2);Calc2(i + 3);
    	    Calc2(i + 4);Calc2(i + 5);
    	    Calc2(i + 6);Calc2(i + 7);
    	}
        }
        NTT(f,K,-1,IW);NTT(g,K,-1,IW);
        ++top;
        F[top].clear();G[top].clear();
        t = min(N - 2,K - 1);
        while(t >= 0) {
    	if(!f[t]) --t;
    	else break;
        }
        for(RG int i = 0 ; i <= t ; ++i) F[top].pb(f[i]);
        t = min(N - 2,K - 1);
        while(t >= 0) {
    	if(!g[t]) --t;
    	else break;
        }
        for(RG int i = 0 ; i <= t ; ++i) G[top].pb(g[i]);
    }
    int main() {
    #ifdef ivorysi
        freopen("f1.in","r",stdin);
    #endif
        Init();
        Solve(1,N);
        G[top].resize(N - 1);
        RG int ans = 0,sum = 0,p = 1,t = 1;
        for(RG int i = 1 ; i <= N ; ++i) sum = inc(sum,a[i]),p = mul(p,a[i]);
        for(RG int i = N - 2 ; i >= 0 ; --i) {
    	ans = inc(ans,mul(mul(G[top][i],t),invfac[N - 2 - i]));
    	t = mul(t,sum);
        }
        ans = mul(mul(ans,p),fac[N - 2]);
        out(ans);enter;
        //out(clock());enter;
        return 0;
    }
    
  • 相关阅读:
    PyCharm 安装package matplotlib为例
    Julia 下载 安装 juno 开发环境搭建
    进程 线程 协程
    Eclipse Golang 开发环境搭建 GoClipse 插件
    TaxonKit
    tar: Removing leading `/' from member names
    Linux 只列出目录的方法
    unbuntu 安装 teamviewer
    ubuntu 设置静态IP
    Spring 配置文件中 元素 属性 说明
  • 原文地址:https://www.cnblogs.com/ivorysi/p/9093086.html
Copyright © 2011-2022 走看看