zoukankan      html  css  js  c++  java
  • CF1010F

    CF1010F Tree [* easy]

    给定一棵根节点为 (1) 的二叉树 (T),你需要先保留一个包含 (1) 号节点的连通块,然后给每个点确定一个权值 (a_i),使得对于每个点 (u) 都有其权值 (a_u) 大于等于其所有儿子的权值和 (sum a_v[(u,v)in T])

    最后,你需要使得根节点权值为 (m),求方案数,答案对 (998244353) 取模。

    (nle 10^5,mle 10^{18})

    Solution

    假设在 (i) 处放一个权值,那么其所有祖先都要放一个权值。

    于是我们假设确定了每个点的 "额外权值" 那么合法等价于 (m) 大于等于额外权值和。

    我们可以这样考虑,假设以 (1) 为根的连通块有 (L) 个点,那么此时的方案数显然就是 (frac{1}{(1-x)^{L}}[x^m])(inom{L+m-1}{m})

    于是我们只需要知道以 (1) 为根的,大小为 (i) 的连通块有多少个,使用生成函数来刻画答案,那么转移形如:

    [x(L(x)+1)(R(x)+1) ]

    暴力背包,(mathcal O(n^2))(树背包复杂度)

    接下来考虑优化,我们先找到一条重链,然后对忽略此重链的树递归,对于一条重链,假设每个点都有另一个儿子(没有将多项式设为 (0))此时我们相当于计算,给定一个序列和多项式 (F_2(x),F_3(x)...)(注意 (F_1(x)) 为空)(方便起见给所有多项式先加 (1),然后给 (F_2(x)) 乘以 ((x+1))),求:

    [(((F_2(x)x+1)F_3(x)x+1)F_4(x)x+1)... ]

    方便起见令 (G_i(x)=F_i(x)x),那么就有:

    [((G_2(x)+1)G_3(x)+1)... ]

    不难发现这个算式相当于将 (G) 翻转后统计 (sum_i prod_{jle i}G_j(x))

    使用分治 NTT 加速即可,复杂度为 (mathcal O(nlog^3 n))

    复杂度分析:

    (T(n)) 表示复杂度,则最后一次合并的复杂度为 (mathcal O(nlog^2 n))

    接下来,对于每棵子树,由于重链剖分,问题规模至少缩小了一半,且问题规模和仍然是 (n),于是递归层数至多为 (log n),每层复杂度仍为 (mathcal O(sum ( extrm{size})log^2 ( extrm{size}))=mathcal O(nlog^2 n)) 总复杂度为 (mathcal O(nlog^3 n))

    (Code:)

    #include<bits/stdc++.h>
    using namespace std ;
    #define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
    #define rep( i, s, t ) for( register int i = (s); i <= (t); ++ i )
    #define Rep(i, s, t) for(register int i = (s); i < (t); ++ i)
    #define drep( i, s, t ) for( register int i = (t); i >= (s); -- i )
    #define re register
    #define mp make_pair
    #define pi pair<int, int>
    #define pb push_back
    #define int long long
    #define vi vector<int>
    int gi() {
    	char cc = getchar() ; int cn = 0, flus = 1 ;
    	while( cc < '0' || cc > '9' ) {  if( cc == '-' ) flus = - flus ; cc = getchar() ; }
    	while( cc >= '0' && cc <= '9' )  cn = cn * 10 + cc - '0', cc = getchar() ;
    	return cn * flus ;
    }
    const int N = 4e5 + 5 ; 
    const int P = 998244353 ; 
    const int G = 3 ; 
    const int Gi = 332748118 ; 
    int fpow(int x, int k) {
    	int ans = 1, base = x ;
    	while(k) {
    		if(k & 1) ans = 1ll * ans * base % P ;
    		base = 1ll * base * base % P, k >>= 1 ;
    	} return ans ;
    }
    int n, X, ch[N], sz[N], deg[N], R[N], fa[N], fac[N], inv[N], ind[N], L, limit, Inv ; 
    vector<int> F[N], E[N] ; 
    void init(int x) {
    	limit = 1, L = 0 ; while( limit < x ) limit <<= 1, ++ L ; 
    	Rep(i, 0, limit) R[i] = (R[i >> 1] >> 1) | ((i & 1) << (L - 1)) ; 
    	Inv = fpow(limit, P - 2) ; 
    }
    void NTT(vi& a, int type) {
    	Rep(i, 0, limit) if(R[i] > i) swap( a[i], a[R[i]] ) ; 
    	for(re int k = 1; k < limit; k <<= 1) {
    		int d = fpow( (type) ? G : Gi, (P - 1) / (k << 1) ) ;
    		for(re int i = 0; i < limit; i += (k << 1))
    		for(re int j = i, g = 1; j < i + k; ++ j, g = g * d % P) {
    			int nx = a[j], ny = a[j + k] * g % P ; 
    			a[j] = (nx + ny) % P, a[j + k] = (nx - ny + P) % P ; 
    		}
    	}
    	if( !type ) Rep(i, 0, limit) a[i] = a[i] * Inv % P ; 
    }
    void dfs1(int x, int ff) {
    	sz[x] = 1, fa[x] = ff ; 
    	for(int v : E[x]) if(v ^ ff) {
    		dfs1(v, x), sz[x] += sz[v], ++ deg[x] ;
    		if( sz[v] >= sz[ch[x]] ) ch[x] = v ; 
    	}
    }
    vector<int> p[N], f[N], st[N] ; int top ; 
    vi operator + (vi a, vi b) {
    	vi ans ; int cnt = max(a.size(), b.size()) ; 
    	ans.resize(cnt), a.resize(cnt), b.resize(cnt) ; 
    	Rep(i, 0, cnt) ans[i] = a[i] + b[i] ; 
    	return ans ; 
    }
    vi operator * (vi a, vi b) {
    	vi ans ; init(a.size() + b.size() + 2) ; int cnt = a.size() + b.size() - 1 ; 
    	ans.resize(limit), a.resize(limit), b.resize(limit) ; 
    	NTT(a, 1), NTT(b, 1) ;
    	Rep( i, 0, limit ) ans[i] = a[i] * b[i] % P ; 
    	NTT(ans, 0), ans.resize(cnt) ; 
    	return ans ; 
    }
    void Solve(int l, int r) {
    	if(l == r) { p[l] = st[l], f[l] = st[l], ++ f[l][0] ; return ; }
    	int mid = (l + r) >> 1 ; 
    	Solve(l, mid), Solve(mid + 1, r) ; 
    	vi fl = f[l], pr = p[mid + 1], fr = f[mid + 1] ; 
    	-- fl[0], f[l] = (fl * pr + fr), p[l] = p[l] * p[mid + 1] ; 
    }
    void count(int x) {
    	Solve(1, top), F[x] = f[1] ; top = 0 ; 
    }
    void solve(int x, int ff) {
    	if( deg[x] <= 1 ) F[x].resize(1), F[x][0] = 1 ; 
    	for(int v : E[x]) {
    		if(v == fa[x] || v == ch[x]) continue ; 
    		solve(v, v), F[x] = F[v] ; 
    	}
    	if( ch[x] ) solve(ch[x], x) ; 
    	int cnt = F[x].size() ; F[x].resize(cnt + 1) ; 
    	drep( i, 1, cnt ) F[x][i] = F[x][i - 1] ; 
    	F[x][0] = 0, st[++ top] = F[x] ; 
    	if( x == ff ) count(x) ; 
    	cnt = F[x].size() ; 
    }
    signed main()
    {
    	n = gi(), X = gi() ; int x, y ; 
    	rep( i, 2, n ) x = gi(), y = gi(), E[x].pb(y), E[y].pb(x) ; 
    	dfs1(1, 1), solve(1, 1) ; 
    	int Ans = 0 ; fac[0] = inv[0] = ind[0] = 1 ; 
    	rep( i, 1, n ) fac[i] = fac[i - 1] * i % P ;  
    	rep( i, 1, n ) inv[i] = fpow( fac[i], P - 2 ) ; 
    	rep( i, 1, n ) ind[i] = ind[i - 1] * ((X + i) % P) % P ; 
    	rep( i, 0, n - 1 ) Ans = (Ans + ind[i] * inv[i] % P * F[1][i + 1] % P) % P ; 
    	cout << Ans << endl ; 
    	return 0 ;
    }
    
  • 相关阅读:
    mysql 错误 1067: 进程意外终止
    VPS主机MSQL意外中断重启就好但10来个小时又中断的了如些反复
    使用hibernate连接mysql自动中断的问题
    40个国外联盟
    从服务里删除mysql
    外国广告联盟[16个]
    stm32学习笔记:GPIO外部中断的使用
    NO.2 设计包含min 函数的栈
    GPS数据,实测
    LATEX使用总结
  • 原文地址:https://www.cnblogs.com/Soulist/p/14027920.html
Copyright © 2011-2022 走看看