zoukankan      html  css  js  c++  java
  • 【题解】CSP2019 简要题解

    D1T1 code

    签到题,大家都会。

    可以从高位往低位确定,如果遇到 \(1\),则将排名取反一下。

    注意要开 unsigned long long

    #include <bits/stdc++.h>
    
    typedef unsigned long long u64; 
    
    const int MaxN = 100; 
    
    u64 n, K; 
    bool ans[MaxN]; 
    
    inline void solve(u64 dep, u64 k)
    {
    	if (dep == 0)
    		return; 
    	
    	u64 lsze = 1ull << (dep - 1); 
    	if (k < lsze)
    	{
    		ans[dep] = false; 
    		solve(dep - 1, k); 
    	}
    	else
    	{
    		ans[dep] = true; 
    		solve(dep - 1, lsze - (k - lsze) - 1); 
    	}
    }
    
    int main()
    {
        freopen("code.in", "r", stdin); 
        freopen("code.out", "w", stdout); 
        
    	std::cin >> n >> K; 
    	
    	solve(n, K); 
    	
    	for (int i = n; i >= 1; --i)
    		putchar(ans[i] ? '1' : '0'); 
    	
    	return 0; 
    }
    

    D1T2 brackets

    简单题,大家都会。

    大家的做法都好巨,我只会奇奇怪怪的做法。

    考虑每次加一个括号之后答案的增量,显然只有加右括号的时候答案会增加。

    我们记两个量, \(lst\_lef_i\) 表示 \(i\to 1\) 的路径上最后一个没有匹配的左括号,\(lst\_blk_i\) 表示以 \(i\) 结尾的合法子串个数(本质上是数出类似这种串 ...(...)(...)(...)(...) 的极长合法后缀可以分成几段 (...))。

    这两个量可以直接线性推出来,然后就做完了。时间复杂度 \(O(n)\)

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &x)
    {
    	static char ch; 
    	while (!isdigit(ch = getchar())); 
    	x = ch - '0'; 
    	while (isdigit(ch = getchar()))
    		x = x * 10 + ch - '0'; 
    }
    
    typedef long long s64; 
    
    const int MaxNV = 5e5 + 5; 
    const int MaxNE = MaxNV; 
    
    int n; 
    int fa[MaxNV]; 
    char s[MaxNV]; 
    
    int lst_lef[MaxNV]; 
    int lst_blk[MaxNV]; 
    
    s64 ans[MaxNV], xor_ans; 
    
    int main()
    {
        freopen("brackets.in", "r", stdin); 
        freopen("brackets.out", "w", stdout); 
        
    	scanf("%d%s", &n, s + 1); 
    	for (int i = 2; i <= n; ++i)
    		read(fa[i]); 
    	
    	if (s[1] == '(')
    		lst_lef[1] = 1; 
    	
    	for (int u = 2; u <= n; ++u)
    	{
    		ans[u] = ans[fa[u]]; 
    		
    		if (s[u] == '(')
    		{
    			lst_lef[u] = u; 
    			lst_blk[u] = 0; 
    		}
    		else
    		{
    			if (lst_lef[fa[u]])
    			{
    				int lef_u = lst_lef[fa[u]]; 
    				
    				lst_lef[u] = lst_lef[fa[lef_u]]; 
    				lst_blk[u] = lst_blk[fa[lef_u]] + 1; 
    				ans[u] += lst_blk[u]; 
    			}
    			else
    			{
    				lst_lef[u] = 0; 
    				lst_blk[u] = 0; 
    			}
            }
    		xor_ans ^= 1LL * u * ans[u]; 
    	}
    	
    	std::cout << xor_ans << std::endl; 
    	
    	return 0; 
    }
    

    D1T3 tree

    细节题。这个题的 idea 挺好的,就是容易分类讨论挂。

    难度其实不大,放在 D1T3 其实没啥毛病。可能出题人高估了我的代码能力。我太菜了,考场上调不出来。

    一个显然的贪心就是从小到大枚举数字,然后判断这个数字最终能送到那个位置。显然每次我们都贪心地选取最小的位置。

    考虑一条路径 \(u_1\to u_2\to \dots \to u_k\),假设我们要将 \(u_1\) 的原来数字送到 \(u_k\),那么需要满足下列条件:

    • \((u_1,u_2)\) 是和 \(u_1\) 相连的所有边中第一个被删除的
    • \((u_{k - 1}, u_k)\) 是和 \(u_k\) 相连的所有边中最后一个被删除的
    • \((u_{i-1},u_i)\) 必须比 \((u_i,u_{i+1})\) 先删除,并且在删除 \((u_{i-1},u_i)\) 后,删除 \((u_i,u_{i+1})\) 之前,不能有和 \(u_i\) 相连的其他边被删除。

    那么我们就对每个点,维护出与其相连的所有边的限制。这些限制具体可以用一个链表来表示,并且需要记录每个点强制限制的第一个删除的边,和最后一个删除的边。

    实现的时候,就从当前枚举的数字所在的点开始 dfs,显然满足第一个和第三个条件的点构成一个联通块。我们只需要在 dfs 的时候顺带判断这些条件能否满足即可。

    时间复杂度 \(O(n^2)\)。细节比较多,我是根据考场的混乱思路瞎写的,相信读者一定有比我更优秀的实现方法。

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &x)
    {
    	static char ch; 
    	while (!isdigit(ch = getchar())); 
    	x = ch - '0'; 
    	while (isdigit(ch = getchar()))
    		x = x * 10 + ch - '0'; 
    }
    
    template <class T>
    inline void putint(T x)
    {
    	static char buf[25], *tail = buf; 
    	if (!x)
    		putchar('0'); 
    	else
    	{
    		for (; x; x /= 10) *++tail = x % 10 + '0'; 
    		for (; tail != buf; --tail) putchar(*tail); 
    	}
    }
    
    const int MaxN = 2e3 + 5; 
    
    int n; 
    int idx[MaxN], col[MaxN], fa[MaxN]; 
    int adj[MaxN][MaxN], deg[MaxN]; 
    
    int ans[MaxN]; 
    int fir[MaxN], lst[MaxN]; 
    int head[MaxN][MaxN], sze[MaxN][MaxN]; 
    int pre[MaxN][MaxN], suf[MaxN][MaxN]; 
    
    inline void init()
    {
    	read(n); 
    	for (int i = 1; i <= n; ++i)
    	{
    		fir[i] = lst[i] = 0; 
    		deg[i] = ans[i] = fa[i] = 0; 
    
    		for (int j = 1; j <= n; ++j)
    		{
    			head[i][j] = j; 
    			sze[i][j] = 1; 
    			pre[i][j] = suf[i][j] = 0; 
    		}
    	}
    
    	for (int i = 1; i <= n; ++i)
    	{
    		read(idx[i]); 
    		col[idx[i]] = i; 
    	}
    
    	for (int i = 1; i < n; ++i)
    	{
    		int u, v; 
    		read(u), read(v); 
    		adj[u][++deg[u]] = v; 
    		adj[v][++deg[v]] = u; 
    	}
    }
    
    inline void dfs(int u, int src)
    {
    	if (u != src)
    	{
    		bool flg = true; 
    
    		if (deg[u] != 1)
    		{
    			flg &= fir[u] != fa[u] && !suf[u][fa[u]]; 
    			flg &= !lst[u] || lst[u] == fa[u]; 
    			if (fir[u] && head[u][fir[u]] == head[u][fa[u]])
    				flg &= sze[u][head[u][fa[u]]] == deg[u]; 
    		}
    
    		if (flg)
    		{
    			if (!ans[src] || u < ans[src])
    				ans[src] = u; 
    		}
    	}
    	for (int i = 1; i <= deg[u]; ++i)
    	{
    		int v = adj[u][i]; 
    		if (v == fa[u])
    			continue; 
    
    		fa[v] = u; 
    
    		bool flg = true; 
    		if (u == src)
    		{
    			if (deg[u] != 1)
    			{
    				flg &= lst[u] != v && !pre[u][v]; 
    				if (lst[u] && head[u][lst[u]] == head[u][v])
    					flg &= sze[u][head[u][v]] == deg[u]; 
    			}
    		}
    		else
    		{
    			flg &= !suf[u][fa[u]] || suf[u][fa[u]] == v; 
    			flg &= !pre[u][v] || pre[u][v] == fa[u]; 
    			flg &= suf[u][fa[u]] == v || head[u][v] != head[u][fa[u]]; 
    			flg &= head[u][fir[u]] != head[u][v] && head[u][lst[u]] != head[u][fa[u]]; 
    			if (head[u][lst[u]] == head[u][v] && head[u][fir[u]] == head[u][fa[u]])
    				flg &= sze[u][head[u][lst[u]]] + sze[u][head[u][fir[u]]] == deg[u]; 
    		}
    
    		if (flg)
    			dfs(v, src); 
    	}
    }
    
    inline void modify(int x, int src)
    {
    	if (!x)
    		return; 
    
    	lst[x] = fa[x]; 
    
    	int y = x; 
    	while (fa[y] != src)
    	{
    		suf[fa[y]][fa[fa[y]]] = y; 
    		pre[fa[y]][y] = fa[fa[y]]; 
    
    		if (head[fa[y]][fa[fa[y]]] != head[fa[y]][y])
    		{
    			int l = head[fa[y]][fa[fa[y]]]; 
    			int z = y; 
    			while (z)
    			{
    				++sze[fa[y]][l]; 
    				head[fa[y]][z] = l; 
    				z = suf[fa[y]][z]; 
    			}
    		}
    		y = fa[y]; 
    	}
    
    	fir[src] = y; 
    }
    
    inline void solve()
    {
    	for (int c = 1; c <= n; ++c)
    	{
    		int u = idx[c]; 
    		if (n == 1)
    		{
    			puts("1"); 
    			continue; 
    		}
    
    		fa[u] = 0; 
    		dfs(u, u); 
    		modify(ans[u], u); 
    
    		putint(ans[u]); 
    		putchar(" \n"[c == n]); 
    	}
    }
    
    int main()
    {
        freopen("tree.in", "r", stdin); 
        freopen("tree.out", "w", stdout); 
        
    	int orzczk; 
    	read(orzczk); 
    
    	while (orzczk--)
    	{
    		init(); 
    		solve(); 
    	}
    
    	return 0; 
    }
    

    D2T1 meal

    简单题,就我不会。考场上降智太严重了,会了 \(O(mn^3)\) 竟然不会 \(O(mn^2)\)。我校其他选手全部 AC 此题,水平高下立判。

    因为如果有食材超过一半,那么最多只能有一个这样的食材,所以不难想到用总的方案数减去有一个主要食材超过一半的方案数。总的方案数就是

    \[\prod_{i=1}^n\left(1+\sum_{j=1}^ma_{i,j}\right)-1 \]

    减一是因为不能一个都不选。

    考虑如何限制某个食材超过一半,显然我们可以考虑枚举这个食材,然后把用这个食材的菜权值看成 \(+1\),不用这个食材的菜权值看成 \(-1\),那么相当于选的所有菜的总权值要大于 \(0\)

    具体地,我们可以用一个背包 DP 实现。显然大家都会,就不讲了。

    \(f(i,j)\) 表示前 \(i\) 种方法,选的菜的总权值为 \(j\) 的方案数。

    假设现在限制的是第 \(p\) 种食材,那么转移非常显然:

    \[\begin{aligned} f(i,j) &\leftarrow f(i-1,j-1)\times a_{i,p}\\ f(i,j) &\leftarrow f(i-1,j+1)\times \sum_{j \neq p}a_{i,j} \end{aligned} \]

    为了避免负数下标,可以加一个常数。时间复杂度 \(O(mn^2)\)

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &x)
    {
    	static char ch; 
    	while (!isdigit(ch = getchar())); 
    	x = ch - '0'; 
    	while (isdigit(ch = getchar()))
    		x = x * 10 + ch - '0'; 
    }
    
    const int MaxN = 1e2 + 5; 
    const int MaxM = 2e3 + 5; 
    const int mod = 998244353; 
    
    int n, m; 
    int a[MaxN][MaxM], sum[MaxN]; 
    
    int f[MaxN][MaxN << 1]; 
    
    inline void add(int &x, const int &y)
    {
    	x += y; 
    	if (x >= mod)
    		x -= mod; 
    }
    
    inline void dec(int &x, const int &y)
    {
    	x -= y; 
    	if (x < 0)
    		x += mod; 
    }
    
    inline int minus(int x, const int &y)
    {
    	x -= y; 
    	return x < 0 ? x + mod : x; 
    }
    
    int main()
    {
    	freopen("meal.in", "r", stdin); 
    	freopen("meal.out", "w", stdout); 
    
    	read(n), read(m); 
    	for (int i = 1; i <= n; ++i)
    	{
    		for (int j = 1; j <= m; ++j)
    		{
    			read(a[i][j]); 
    			add(sum[i], a[i][j]); 
    		}
    	}
    
    	int ans = 1; 
        for (int i = 1; i <= n; ++i)
            ans = 1LL * ans * (sum[i] + 1) % mod; 
        
        dec(ans, 1); 
    	for (int p = 1; p <= m; ++p)
    	{
    		f[0][n] = 1; 
    		for (int i = 1; i <= n; ++i)
    			for (int j = 0; j <= (n << 1); ++j)
    			{
    				f[i][j] = f[i - 1][j]; 
    				if (j > 0)
    					add(f[i][j], 1LL * f[i - 1][j - 1] * a[i][p] % mod); 
    				if (j < (n << 1))
    					add(f[i][j], 1LL * f[i - 1][j + 1] * minus(sum[i], a[i][p]) % mod); 
    			}
    
    		for (int i = 1; i <= n; ++i)
    			dec(ans, f[n][i + n]); 
    	}
    
    	printf("%d\n", ans); 
    
    	return 0; 
    }
    

    D2T2 partition

    打表找规律题,考场上来不及了。

    开始我们有一个显然的 DP 是,设 \(f(i,j)\) 表示前 \(i\) 个数,最后一段是 \([j+1,i]\),的最小平方和。这样的 DP 实现地优秀一点可以做到 \(O(n^2)\)

    强烈的感觉告诉我们,这题有奇妙结论,考场上当然是打表。打表后不难发现在合法范围内,\(f(i,j)\)\(j\) 单调递减

    结论的证明参考出题人myy的博客: http://matthew99.blog.uoj.ac/blog/5299

    简单总结一下这个证明:

    结论: 把所有解的断点从大到小写下来,然后剩下的位置补0,那么最优解对应的序列在所有位置都是最大值。(不难发现,这个定义使最优解唯一)

    证明: 结论等价于,对于每个解,从后往前将每一段的和写出来,然后补无限个零,得到一个对应的序列,那么最优解对应的序列任意位置的前缀和都是最小的。

    假设这个对应从后往前写出的每一段和的序列为 \(\{b_i\}\),考虑另一个解对应的序列 \(\{c_i\}\),显然不会出现某个位置 \(k\) 的前缀和 \(\sum_{i=1}^kb_i > \sum_{i=1}^kc_i\),否则就会和最优解的定义矛盾。

    因此现在需要证明的就是对于任意一个满足

    \[\forall k, \sum_{i=1}^kb_i\leq \sum_{i=1}^kc_i \]

    的解对应序列 \(c\)\(c\)\(b\) 不同),都有

    \[\sum_{i=1}^{+\infty}b_i^2 < \sum_{i=1}^{+\infty}c_i^2 \]

    证明的思路是,将序列 \(c\) 经过一些使平方和减小的调整,并且使得任意时刻都满足所有位置的前缀和不小于 \(b\),最后让 \(c\) 变成 \(b\)。这样就能证明 \(c\) 的平方和不小于 \(b\) 的。(注意调整的含义是直接对 \(c\) 进行修改,在调整过程中没有必要保证 \(c\) 存在对应的原序列中的解,我们只关心 \(c\)\(b\) 平方和的大小关系)

    注意到对于一个单调不增的序列 \(a\),若 \(i<j\)\(a_i>a_{i+1},a_{j-1}>a_j, a_i-a_j\geq 2\),将 \(a_i\) 减一,\(a_j\) 加一,可以使 \(a\) 仍然单调不增,并且平方和减小。

    找到第一个满足 \(c_u>b_u\) 的位置 \(u\),在 \(u\) 的后面一定能找到第一个位置 \(v\) 满足 \(c_v<b_v\)(因为 \(c\)\(b\) 所有元素的总和一样)。因为 \(c\)\(b\) 都单调不增,所以 \(c_u-c_v\geq 2\),即区间 \([u,v]\) 的权值跨度至少为 \(2\)。找到最小的满足 \(i\geq u,c_{i}>c_{i+1}\) 的位置 \(i\),找到最大的满足 \(j\leq v,c_{j-1}>c_j\) 的位置 \(j\),那么 \(u\leq i<j\leq v\),并且 \(c_i-c_j\geq 2\)。于是将 \(c_i\) 减一,\(c_j\) 加一,可以使 \(c\) 仍然递增,并且平方和减小,不难发现,这么操作仍然保证了 \(c\) 的每个位置的前缀和不小于 \(b\) 的。

    于是不断这么操作,一定能使 \(c\) 最终变成 \(b\),而在操作过程中平方和不断减小,于是原来的 \(c\) 的平方和是大于 \(b\) 的。\(\square\)

    因此我们对于每个 \(i\) 只要记一个最优决策点即可,每个 \(i\) 的决策点一定是取在最靠右的合法位置。记这个位置为 \(p_i\)。考虑一个决策点 \(j(j<i)\) 是合法的当且仅当 \(s_i-s_j\geq s_j-s_{p_j}\),其中 \(s_i\) 表示 \(1\dots i\) 的前缀和。

    移项一下这个条件就是 \(2s_j-s_{p_j} \leq s_i\),到这里可以用一个BIT 做到 \(O(n \log n)\)

    但是实际上,不难发现 \(s_i\) 是单增的,也就是说一个决策点 \(j\) 对于当前的 \(i\) 是合法的,那么对于后面的肯定仍是合法的。于是我们考虑维护一个\(2s_j-s_{p_j}\) 单增的单调队列。对于每个 \(i\) 就从队头开始找到最后一个合法决策点,再将 \(i\) 插到队尾,并且将队尾的那些 \(2s_j-s_{p_j}\geq 2s_i-s_{p_i}\) 的决策点弹掉。这样的时间复杂度就是 \(O(n)\) 了。

    本题还有个问题就是,你不能一边做这个 DP 一边算这个 DP 值,得把 \(p_i\) 存起来最后算(因为空间不够)。高精还需要压位/二进制。当然,你在 OJ 上用 __int128 也是可以的。

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &x)
    {
    	static char ch; 
    	static bool opt; 
    	while (!isdigit(ch = getchar()) && ch != '-'); 
    	x = (opt = ch == '-') ? 0 : ch - '0'; 
    	while (isdigit(ch = getchar()))
    		x = x * 10 + ch - '0'; 
    	if (opt)
    		x = ~x + 1; 
    }
    
    template <class T>
    inline void putint(T x)
    {
    	static char buf[45], *tail = buf; 
    	if (!x)
    		putchar('0'); 
    	else
    	{
    		if (x < 0)
    		{
    			putchar('-'); 
    			x = ~x + 1; 
    		}
    		for (; x; x /= 10) *++tail = x % 10 + '0'; 
    		for (; tail != buf; --tail) putchar(*tail); 
    	}
    }
    
    typedef long long s64; 
    
    const int MaxN = 4e7 + 5; 
    const s64 mod = 1e9; 
    
    int n, type, ql, qr; 
    int que[MaxN], maxp[MaxN], b[MaxN]; 
    s64 s[MaxN]; 
    
    struct bignum
    {
    	int len; 
    	s64 a[7]; 
    	bignum(){}
    	bignum(s64 t)
    	{
    		len = 1; 
    		memset(a, 0, sizeof(a)); 
    
    		a[1] = t % mod; 
    		if (t >= mod)
    			a[++len] = t / mod; 
    	}
    
    	inline void operator += (const bignum &rhs)
    	{
    		len = std::max(len, rhs.len); 
    		for (int i = 1; i <= len; ++i)
    		{
    			a[i] += rhs.a[i]; 
    			a[i + 1] += a[i] / mod; 
    			a[i] %= mod; 
    		}
    		if (a[len + 1])
    			++len; 
    	}
    
    	inline bignum operator * (const bignum &rhs) const
    	{
    		bignum res(0); 
    		res.len = len + rhs.len; 
    		for (int i = 1; i <= len; ++i)
    			for (int j = 1; j <= rhs.len; ++j)
    				res.a[i + j - 1] += a[i] * rhs.a[j]; 
    		for (int i = 1; i < res.len; ++i)
    		{
    			res.a[i + 1] += res.a[i] / mod; 
    			res.a[i] %= mod; 
    		}
    		while (res.len > 1 && !res.a[res.len])
    			--res.len; 
    		return res; 
    	}
    
    	inline void print()
    	{
    		printf("%d", (int)a[len]); 
    		for (int i = len - 1; i >= 1; --i)
    			printf("%09d", (int)a[i]); 
    	}
    }res(0); 
    
    inline s64 calc(int x)
    {
    	return 2 * s[x] - s[maxp[x]]; 
    }
    
    int main()
    {
    	freopen("partition.in", "r", stdin); 
    	freopen("partition.out", "w", stdout); 
    
    	read(n), read(type); 
    	if (type == 0)
    	{
    		for (int i = 1; i <= n; ++i)
    		{
    			int x; 
    			read(x); 
    			s[i] = s[i - 1] + x; 
    		}
    	}
    	else
    	{
    		int x, y, z, m; 
    		read(x), read(y), read(z), read(b[1]), read(b[2]), read(m); 
    		for (int i = 1, lstp = 0; i <= m; ++i)
    		{
    			int p, l, r; 
    			read(p), read(l), read(r); 
    			for (int j = lstp + 1; j <= p; ++j)
    			{
    				if (j > 2)
    					b[j] = (1LL * x * b[j - 1] + 1LL * y * b[j - 2] + z) % (1 << 30); 
    				s[j] = s[j - 1] + b[j] % (r - l + 1) + l; 
    			}
    			lstp = p; 
    		}
    	}
    	
    	que[ql = qr = 1] = 0; 
    	for (int i = 1; i <= n; ++i)
    	{
    		while (ql < qr && calc(que[ql + 1]) <= s[i])
    			++ql; 
    		maxp[i] = que[ql]; 
    		while (ql <= qr && calc(que[qr]) >= calc(i))
    			--qr; 
    		que[++qr] = i; 
    	}
    
    	int cur = n; 
    	while (cur)
    	{ 
    		res += bignum(s[cur] - s[maxp[cur]]) * bignum(s[cur] - s[maxp[cur]]); 
    		cur = maxp[cur]; 
    	}
    
    	res.print(); 
    	
    	return 0; 
    }
    

    D2T3 centroid

    不难题,就我不会。因为这是 D2T3 所以我只想着写暴力,实际上这题挺简单的。

    将题目中的计算方式转化为:对于每个点,计算它成为重心的方案数。

    那么考虑一个点 \(u\),将其硬点为,对于它的每个儿子 \(v\),相当于在 \(v\) 的子树再选出一个大小为 \(s_0\) 的子树,这个 \(s_0\) 需要在一个范围。具体地,我们记 \(s'\) 表示 \(u\) 除了 \(v\) 以外的其他儿子的最大子树大小,记 \(s_v\) 表示 \(v\) 的子树大小。

    那么就有

    \[s_v-s_0 \leq \left\lfloor\frac {n-s_0}2\right\rfloor \\ s_0 \geq 2s_v-n \]

    以及

    \[s^\prime\leq \left\lfloor \frac {n-s_0}2\right\rfloor\\ s_0\leq n-2s^\prime \]

    那么就要求 \(s_0\in [2s_v-n,n-2s^\prime]\)

    对应到原树(钦定 \(1\) 为根的有根树)上,这相当于分两类讨论:

    • \(v\)\(u\) 的一个儿子,这个时候直接在 dfs 时用一个 BIT 存下当前结点到根的路径上所有点的询问区间对应的贡献,然后遍历到一个点的时候在 BIT 上单点查询即可。
    • \(v\)\(u\) 的父亲,这时候的统计比较复杂,需要分成几块统计:
      • 原树中不在 \(u\) 子树中且不在 \(u\) 到根路径上的结点的 \(size\in [2(n-s_u)-n,n-2s^\prime]\),这个可以用整棵树\(size\) 在这个区间的点数减去子树内 \(size\) 在这个区间的点数,再减去到根结点的路径上的点 \(size\) 在这个区间的点数。子树询问可以归到第一类的 BIT 中,到根结点的可以另外再维护一个 BIT。
      • 原树中 \(u\) 到根的路径上,向父亲方向的「子树」。这个可以 dfs 的时候归到上面第二个 BIT 维护。

    那么这样我们就可以用 BIT 来做了。时间复杂度 \(O(n\log n)\),常数有点大,可能打不过线段树选手。

    (补充:我太菜了,其实那些询问看成二维数点然后离线 BIT 就可以了,不用我这么麻烦,但是这样常数好像也很大

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &x)
    {
    	static char ch; 
    	while (!isdigit(ch = getchar())); 
    	x = ch - '0'; 
    	while (isdigit(ch = getchar()))
    		x = x * 10 + ch - '0'; 
    }
    
    template <class T>
    inline void putint(T x)
    {
    	static char buf[25], *tail = buf; 
    	if (!x)
    		putchar('0'); 
    	else
    	{
    		for (; x; x /= 10) *++tail = x % 10 + '0'; 
    		for (; tail != buf; --tail) putchar(*tail); 
    	}
    }
    
    template <class T>
    inline void relax(T &x, const T &y)
    {
    	if (x < y)
    		x = y; 
    }
    
    template <class T>
    inline void tense(T &x, const T &y)
    {
    	if (x > y)
    		x = y; 
    }
    
    typedef long long s64; 
    
    const int MaxNV = 3e5 + 5; 
    const int MaxNE = MaxNV << 1; 
    
    int n; 
    int ect, adj[MaxNV]; 
    int nxt[MaxNE], to[MaxNE]; 
    
    int fa[MaxNV], sze[MaxNV]; 
    s64 ans, sum[MaxNV], bit_q[MaxNV], bit_u[MaxNV]; 
    
    int max_sze[MaxNV][2]; 
    
    #define trav(u) for (int e = adj[u], v; v = to[e], e; e = nxt[e])
    
    inline void addEdge(int u, int v)
    {
    	nxt[++ect] = adj[u]; 
    	adj[u] = ect; 
    	to[ect] = v; 
    }
    
    inline void init()
    {
    	read(n); 
    
    	ect = 0; 
    	ans = 0; 
    	for (int i = 1; i <= n; ++i)
    	{
    		adj[i] = sum[i] = 0; 
    		bit_q[i] = bit_u[i] = 0; 
    		max_sze[i][0] = max_sze[i][1] = 0; 
    	}
    
    	for (int i = 1; i < n; ++i)
    	{
    		int u, v; 
    		read(u), read(v); 
    		addEdge(u, v); 
    		addEdge(v, u); 
    	}
    }
    
    inline void bit_modify(int x, int val, s64 *bit)
    {
    	for (; x <= n; x += x & -x)
    		bit[x] += val; 
    }
    
    inline s64 bit_query(int x, s64 *bit)
    {
    	s64 res = 0; 
    	for (; x; x ^= x & -x)
    		res += bit[x]; 
    	return res; 
    }
    
    inline void seg_modify(int l, int r, int val, s64 *bit)
    {
    	relax(l, 1), tense(r, n); 
    	if (r < l) return; 
    
    	bit_modify(l, val, bit); 
    	bit_modify(r + 1, -val, bit); 
    }
    
    inline s64 seg_query(int l, int r, s64 *bit)
    {
    	relax(l, 1), tense(r, n); 
    	if (r < l) return 0; 
    	return bit_query(r, bit) - bit_query(l - 1, bit); 
    }
    
    inline void upt(int x, int s)
    {
    	if (s >= max_sze[x][0])
    	{
    		max_sze[x][1] = max_sze[x][0]; 
    		max_sze[x][0] = s; 
    	}
    	else
    		relax(max_sze[x][1], s); 
    }
    
    inline int max_else(int x, int s)
    {
    	if (s == max_sze[x][0])
    		return max_sze[x][1]; 
    	return max_sze[x][0]; 
    }
    
    inline void dfs_init(int u)
    {
    	sze[u] = 1; 
    	trav(u)
    		if (v != fa[u])
    		{
    			fa[v] = u; 
    			dfs_init(v); 
    
    			sze[u] += sze[v]; 
    			upt(u, sze[v]); 
    		}
    	upt(u, n - sze[u]); 
    }
    
    inline void dfs(int u)
    {
    	int ul = std::max(1, 2 * (n - sze[u]) - n); 
    	int ur = std::min(n, n - 2 * max_else(u, n - sze[u])); 
    	if (ul <= ur)
    		ans += (sum[ur] - sum[ul - 1]) * u; 
    	seg_modify(ul, ur, -u, bit_q); 
    	ans += bit_query(sze[u], bit_q); 
    	ans += seg_query(ul, ur, bit_u) * u; 
    
    	trav(u)
    		if (v != fa[u])
    		{
    			int t_s = max_else(u, sze[v]); 
    			int l = 2 * sze[v] - n, r = n - 2 * t_s; 
    			
    			seg_modify(l, r, u, bit_q); 
    			bit_modify(sze[u], -1, bit_u); 
    			bit_modify(n - sze[v], 1, bit_u); 
    
    			dfs(v); 
    
    			seg_modify(l, r, -u, bit_q);
    			bit_modify(sze[u], 1, bit_u); 
    			bit_modify(n - sze[v], -1, bit_u);  
    		}
    
    	seg_modify(ul, ur, u, bit_q); 
    }
    
    inline void solve()
    {
    	dfs_init(1); 
    
    	for (int i = 1; i <= n; ++i)
    		++sum[sze[i]]; 
    	for (int i = 1; i <= n; ++i)
    		sum[i] += sum[i - 1]; 
    
    	dfs(1); 
    
    	putint(ans); 
    	putchar('\n'); 
    }
    
    int main()
    {
    	freopen("centroid.in", "r", stdin); 
    	freopen("centroid.out", "w", stdout); 
    
    	int orzczk; 
    	read(orzczk); 
    	while (orzczk--)
    	{
    		init(); 
    		solve(); 
    
    	}
    	return 0; 
    }
    
  • 相关阅读:
    软件设计和开发是手艺活也是艺术活
    学界老师和业界专业人员的紧密合作才能促进软件设计开发教学的进步
    最简单的 GitExtensions 教程(持续更新中)
    最简单的 IntelliJ IDEA 中使用 GitHub 进行版本控制教程(持续更新中)
    工作室成员 GitHub 地址集中贴(按发布时间先后排序)
    使用 Visual Studio Code 运行 C# 及 Java 程序
    推荐一个非常好的 IntelliJ IDEA 教程
    Commit message 和 Change log 编写指南(转自阮一峰的博客)
    关于编码规范的延伸资料(来自于福州大学陈世发同学的博客)
    【扩展阅读】提问的智慧(转自福州大学陈世发同学的评论)
  • 原文地址:https://www.cnblogs.com/cyx0406/p/11908087.html
Copyright © 2011-2022 走看看