zoukankan      html  css  js  c++  java
  • [题解] [AGC005F] Many Easy Problems

    题面

    题解

    学长讲课题目的质量果然和我平常找的那些不一样

    思路还是可以说比较巧妙的

    考虑到我们并不好算出对于所有大小为 (i) 的点集,能够包含它的最小连通块大小

    转换题目

    这个时候我们应该想到把目标放到单个点 (i) 对选择 (k) 个点时的贡献

    那么他的贡献就是总方案数减去没选的方案数对吧

    没选的方案怎么算呢

    (sz[i]) 为以 (i) 为根的子树的点数

    分析发现, 当 (k) 个点同时在以它的儿子为根的一棵子树内就不会计算 (i)

    这里的儿子的意义是, 以 (i) 为全树的根时它的儿子

    也就是说, 以它的父亲为根, 大小为 (n - sz[i]) 的子树也算作 (i) 的子树

    然后大力推式子

    [displaystyleegin{aligned}f(k) = sum_{u=1}^n(C_n^k-sum_{vin sonu}C_{sz[v]}^k)end{aligned} ]

    (C_n^k) 是总方案数, 后面那个就是不会计算 (i) 的方案数

    但是我们又发现, 这个东西不是很好算

    于是又一次转换, 我们设 (cnt_i) 为大小为 (i) 的子树的个数

    然后把相同大小的子树的贡献归集起来有

    [displaystyleegin{aligned}f(k)&=n*C_n^k - sum_{i=k}^ncnt_i*C_i^k\&=n*C_n^k-sum_{i=k}^{n}cnt_i*frac{i!}{k!*(i-k)!}\&=n*C_n^k-frac{1}{k!}sum_{i=k}^{n}frac{cnt_i*i!}{(i-k)!}end{aligned} ]

    (F_i = cnt_i*i!) , (G_i = frac{1}{i!}) , (H_i = G_{n-i})

    所以有

    [displaystyleegin{aligned}f(k)&=n*C_n^k-frac{1}{k!}sum_{i=k}^{n}F_i*H_{n+k-i}end{aligned} ]

    NTT 即可

    Code

    #include <algorithm>
    #include <iostream>
    #include <cstring>
    #include <cstdio>
    const int mod = 924844033;
    const int N = 200005; 
    using namespace std;
    
    int m, lim, n, g, ig, p[105], cnt, f[N * 8], inv[2 * N], fac[2 * N], h[N * 8], sz[N], head[N], rev[N * 8];
    struct edge { int to, nxt; } e[N << 1]; 
    
    template < typename T >
    inline T read()
    {
    	T x = 0, w = 1; char c = getchar();
    	while(c < '0' || c > '9') { if(c == '-') w = -1; c = getchar(); }
    	while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    	return x * w; 
    }
    
    inline void adde(int u, int v) { e[++cnt] = (edge) { v, head[u] }, head[u] = cnt; }
    
    int fpow(int x, int y)
    {
    	int res = 1;
    	for( ; y; y >>= 1, x = 1ll * x * x % mod)
    		if(y & 1) res = 1ll * res * x % mod;
    	return res; 
    }
    
    int getroot(int x)
    {
    	int rem = x - 1, tmp = x - 1;
    	for(int i = 2; i * i <= rem; i++)
    		if(!(tmp % i))
    		{
    			p[++cnt] = i;
    			while(!(tmp % i)) tmp /= i; 
    		}
    	if(tmp > 1) p[++cnt] = tmp;
    	for(int flag = 0, i = 2; i <= x; i++, flag = 0)
    	{
    		for(int j = 1; j <= cnt; j++)
    			if(fpow(i, rem / p[j]) == 1) flag = 1;
    		if(!flag) return i; 
    	}
    }
    
    void pre()
    {
    	for(int i = (fac[0] = 1); i <= 2 * n; i++)
    		fac[i] = 1ll * fac[i - 1] * i % mod; 
    	inv[2 * n] = fpow(fac[2 * n], mod - 2); 
    	for(int i = 2 * n - 1; i >= 0; i--)
    		inv[i] = 1ll * inv[i + 1] * (i + 1) % mod; 
    	for(int i = 1; i <= n; i++)
    		h[i] = inv[n - i]; 
    }
    
    int C(int n, int m) { return 1ll * fac[n] * inv[m] % mod * inv[n - m] % mod; }
    
    void dfs(int u, int fa)
    {
    	sz[u] = 1; 
    	for(int v, i = head[u]; i; i = e[i].nxt)
    	{
    		v = e[i].to; if(v == fa) continue; 
    		dfs(v, u), sz[u] += sz[v]; 
    	}
    	f[sz[u]]++, f[n - sz[u]]++; 
    }
    
    void ntt(int *p, int opt)
    {
    	for(int i = 0; i < n; i++)
    		if(i < rev[i]) swap(p[i], p[rev[i]]);
    	for(int i = 1; i < n; i <<= 1)
    	{
    		int rt = fpow(opt == 1 ? g : ig, (mod - 1) / (i << 1));
    		for(int j = 0; j < n; j += (i << 1))
    		{
    			int w = 1;
    			for(int k = j; k < j + i; k++, w = 1ll * w * rt % mod)
    			{
    				int x = p[k], y = 1ll * w * p[k + i] % mod;
    				p[k] = (x + y) % mod, p[k + i] = (x - y + mod) % mod; 
    			}
    		}
    	}
    }
    
    int main()
    {
    	n = read <int> (), g = getroot(mod), ig = fpow(g, mod - 2); 
    	cnt = 0, pre(); 
    	for(int u, v, i = 1; i < n; i++)
    	{
    		u = read <int> (), v = read <int> (); 
    		adde(u, v), adde(v, u); 
    	}
    	dfs(1, 0); 
    	for(m = 3 * n, n = 1; n <= m; n <<= 1, lim++); lim--; 
    	for(int i = 0; i < n; i++)
    		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lim);
    	f[m / 3]--;
    	for(int i = 0; i < m / 3; i++)
    		f[i] = 1ll * f[i] * fac[i] % mod; 
    	ntt(f, 1), ntt(h, 1); 
    	for(int i = 0; i < n; i++)
    		f[i] = 1ll * f[i] * h[i] % mod; 
    	ntt(f, -1); 
    	int tmp = fpow(n, mod - 2); 
    	for(int i = 0; i < n; i++)
    		f[i] = 1ll * f[i] * tmp % mod; 
    	m /= 3; 
    	for(int i = 1; i <= m; i++)
    		printf("%lld
    ", (1ll * m * C(m, i) % mod - 1ll * inv[i] * f[m + i] % mod + mod) % mod); 
    	return 0; 
    }
    
  • 相关阅读:
    memcached简单介绍及在django中的使用
    【Python】解决使用pyinstaller打包Tkinker程序报错问题
    【python】获取列表中最长连续数字
    【zabbix】zabbix忘记密码,重置密码
    【jenkins】jenkins实时显示python脚本输出
    【AWS】AWS云计算赋能数字化转型专题研讨会圆满落幕
    【深度学习】使用opencv在视频上添加文字和标记框
    【AWS】AWS云计算赋能数字化转型专题研讨会
    【AWS】订阅AWS论坛的RSS消息获取最新公告
    【saltstack】saltstack执行结果和事件存储到mysql
  • 原文地址:https://www.cnblogs.com/ztlztl/p/12194874.html
Copyright © 2011-2022 走看看