zoukankan      html  css  js  c++  java
  • [题解][Codeforces]Codeforces Round #635 (Div. 1) 简要题解

    • Chinese Round 果然对中国选手十分友好(

    • 原题解

    A

    题意

    • 给定一棵 (n) 个节点的有根树和一个 (k),满足 (1le kle n)

    • 选出 (k) 个点为黑点,其他点为白点

    • 求所有黑点到根的路径上白点个数之和的最大值

    • (1le nle 2 imes 10^5)

    做法:贪心

    • 显然一个点为黑点则其子树全为黑点

    • 故问题可以视为 (k) 次,每次删掉一个叶子 (u),贡献为原树(dep_u-size_u)

    • 由于父亲的 (dep-size) 一定小于子节点,故取 (dep-size) 从大到小排序之后前 (k) 大的即可

    • (O(nlog n))

    • 利用 nth_element 可以做到 O(n)

    代码

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &res)
    {
    	res = 0; bool bo = 0; char c;
    	while (((c = getchar()) < '0' || c > '9') && c != '-');
    	if (c == '-') bo = 1; else res = c - 48;
    	while ((c = getchar()) >= '0' && c <= '9')
    		res = (res << 3) + (res << 1) + (c - 48);
    	if (bo) res = ~res + 1;
    }
    
    typedef long long ll;
    
    const int N = 2e5 + 5, M = N << 1;
    
    int n, k, ecnt, nxt[M], adj[N], go[M], dep[N], fa[N], d[N], sze[N], a[N];
    ll ans;
    
    void add_edge(int u, int v)
    {
    	nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
    	nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
    }
    
    void dfs(int u, int fu)
    {
    	fa[u] = fu; dep[u] = dep[fu] + 1; sze[u] = 1;
    	for (int e = adj[u], v; e; e = nxt[e])
    		if ((v = go[e]) != fu) dfs(v, u), d[u]++, sze[u] += sze[v];
    }
    
    int main()
    {
    	int x, y;
    	read(n); read(k);
    	for (int i = 1; i < n; i++) read(x), read(y), add_edge(x, y);
    	dfs(1, 0);
    	for (int i = 1; i <= n; i++) a[i] = dep[i] - sze[i];
    	std::sort(a + 1, a + n + 1);
    	for (int i = n - k + 1; i <= n; i++) ans += a[i];
    	return std::cout << ans << std::endl, 0;
    }
    

    B

    题意

    • 给定三个长度分别为 (n_r,n_g,n_b) 的数组 (r,g,b)

    • 从三个数组中各选一个数,设为 (x,y,z),求 ((x-y)^2+(y-z)^2+(z-x)^2) 的最小值

    • (1le n_r,n_g,n_ble 10^5)(1le r_i,g_i,b_ile 10^9)

    做法:枚举+双指针

    • 假设 (xle yle z),则最优情况下 (x) 要尽可能大,(y) 要尽可能小

    • 故把三个数组排序,枚举 (x,y,z) 大小关系的 (6) 种排列之后,枚举 (y) 的值,用指针维护最大的 (x) 和最小的 (z)

    • (O(n_rlog n_r+n_glog n_g+n_blog n_b))

    代码

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &res)
    {
    	res = 0; bool bo = 0; char c;
    	while (((c = getchar()) < '0' || c > '9') && c != '-');
    	if (c == '-') bo = 1; else res = c - 48;
    	while ((c = getchar()) >= '0' && c <= '9')
    		res = (res << 3) + (res << 1) + (c - 48);
    	if (bo) res = ~res + 1;
    }
    
    typedef long long ll;
    
    const int N = 1e5 + 5;
    const ll INF = 5e18;
    
    int nr, ng, nb, r[N], g[N], b[N];
    
    ll sqr(int x) {return 1ll * x * x;}
    
    ll solve(int na, int nb, int nc, int *a, int *b, int *c)
    {
    	ll ans = INF;
    	for (int i = 1, j = 1, k = 1; j <= nb; j++)
    	{
    		while (i <= na && a[i] <= b[j]) i++;
    		while (k <= nc && b[j] > c[k]) k++;
    		if (i > 1 && k <= nc) ans = std::min(ans,
    			sqr(a[i - 1] - b[j]) + sqr(b[j] - c[k]) + sqr(c[k] - a[i - 1]));
    	}
    	return ans;
    }
    
    void work()
    {
    	read(nr); read(ng); read(nb);
    	for (int i = 1; i <= nr; i++) read(r[i]);
    	for (int i = 1; i <= ng; i++) read(g[i]);
    	for (int i = 1; i <= nb; i++) read(b[i]);
    	std::sort(r + 1, r + nr + 1); std::sort(g + 1, g + ng + 1);
    	std::sort(b + 1, b + nb + 1);
    	ll ans = solve(nr, ng, nb, r, g, b);
    	ans = std::min(ans, solve(nr, nb, ng, r, b, g));
    	ans = std::min(ans, solve(nb, nr, ng, b, r, g));
    	ans = std::min(ans, solve(nb, ng, nr, b, g, r));
    	ans = std::min(ans, solve(ng, nr, nb, g, r, b));
    	ans = std::min(ans, solve(ng, nb, nr, g, b, r));
    	printf("%lld
    ", ans);
    }
    
    int main()
    {
    	int T; read(T);
    	while (T--) work();
    	return 0;
    }
    

    C

    题意

    • 给定长度为 (n) 的串 (S) 和长度为 (m) 的串 (T)

    • 一开始有一个空串 (A)

    • 每次操作可以选择把 (S) 的第一个字符加入 (A) 的开头或末尾,并把 (S) 的第一个字符删掉

    • 你可以执行任意不超过 (n) 的操作次数,求最后能使得 (T)(A) 的前缀的方案数,对 (998244353) 取模

    • (1le mle nle 3000)

    做法:区间 DP

    • (f[l,r]) 表示插入了 (S) 的前 (r-l+1) 个字符,它们组成了最终的 (A) 串的区间 ([l,r]) 的方案数

    • 组成最终的 (A) 串的区间 ([l,r]),也就是说若 (iin[l,r])(ile m),则 (A_i=T_i)

    • 转移即枚举最后一个字符加在左边还是右边,判断其是否符合限制条件即可

    • 答案为 (sum_{i=m}^nf[1,i])

    • (O(n^2))

    代码

    #include <bits/stdc++.h>
    
    const int N = 3005, djq = 998244353;
    
    int n, m, f[N][N], ans;
    char s[N], t[N];
    
    int main()
    {
    	scanf("%s%s", s + 1, t + 1);
    	n = strlen(s + 1); m = strlen(t + 1);
    	for (int i = 1; i <= n + 1; i++) f[i][i - 1] = 1;
    	for (int l = n; l >= 1; l--)
    		for (int r = l; r <= n; r++)
    		{
    			if (l > m || s[r - l + 1] == t[l]) f[l][r] += f[l + 1][r];
    			if (r > m || s[r - l + 1] == t[r]) f[l][r] += f[l][r - 1];
    			if (f[l][r] >= djq) f[l][r] -= djq;
    			if (l == 1 && r >= m)
    				ans = (ans + f[l][r]) % djq;
    		}
    	return std::cout << ans << std::endl, 0;
    }
    

    D

    题意

    • 交互题

    • 你有一堆麻将,点数从 (1)(n),每种点数的麻将个数在 ([0,n]) 之间,但你不知道它们具体是多少

    • 初始时可以知道这堆麻将中,碰(大小为 (3) 且点数相同的子集)的个数和吃(大小为 (3) 且点数形成公差为 (1) 的等差数列)的个数

    • 然后你可以加入最多 (n) 次某一种点数的麻将,加入一个麻将之后你可以得到此时碰和吃的个数

    • 还原初始时每种点数的麻将个数

    • (4le nle 100)

    做法:数学

    • 当前(i) 种麻将有 (c_i) 个,则加入一个第 (i) 种麻将时会多出 (inom{c_i}2) 个碰和 (c_{i-2}c_{i-1}+c_{i-1}c_{i+1}+c_{i+1}c_{i+2}) 个吃

    • 如果只考虑吃的个数,则如果保证 (c_i>0) 则可以通过碰的个数的增量还原出 (c_i)

    • 考虑求点数为 (1) 的个数,可以得到如果事先加入一个 (1),就能保证 (c_i>0),再加入一个 (1) 即可查出 (ans_1)

    • 而加入 (1) 的好处是吃的个数增量为 (c_2c_3)

    • 于是考虑依次加入 (3,1,2,1),这样第二次吃的个数增量为 (ans_2(ans_3+1)),第四次吃的个数增量为 ((ans_2+1)(ans_3+1))

    • 这两个式子作差即可得到 (ans_3)。由于 (ans_3+1>0),故可以使用除法得到 (ans_2)

    • 而实际上我们也可以得到 (ans_4):考虑第三次吃的个数增量:((ans_3+1)(ans_1+1+ans_4)),也可以利用除法得到

    • 而对于 (i>4),也可以加入一个 (i-2),这时吃的个数增量表达式中只有 (ans_i) 是未知量,可以解出来。不过这样有一个问题:(ans_{i-1}) 可能为 (0),这样的方程会有无穷多个解

    • 故考虑倒着加:(n-1,n-2,dots,3,1,2,1)

    • 易得 (3,1,2,1) 移到最后不影响 (ans_{1dots 4}) 的求解,只是 (n>4) 时这样求解出来的 (ans_4) 需要减 (1)(在 (n-1,n-2,dots 4) 中加上了 (1)

    • 然后 (i)(3)(n-2),利用 (i) 被加入时吃的个数增量来解出 (ans_{i+2}),由于 (i+1) 在之前的过程中加过了 (1),故可以保证 (c_{i+1}) 不为 (0),这个方程一定可以解出来

    • (O(n)),操作次数为 (n)

    代码

    #include <bits/stdc++.h>
    
    const int N = 110, M = N * N;
    
    int n, ans[N], f[M], a[N], b[N];
    
    void add(int v) {printf("+ %d
    ", v); fflush(stdout);}
    
    int main()
    {
    	scanf("%d", &n);
    	for (int i = 1; i <= n + 1; i++) f[i * (i - 1) >> 1] = i;
    	scanf("%*d%*d");
    	for (int i = 1; i <= n - 4; i++) add(n - i), scanf("%d%d", &a[i], &b[i]);
    	add(3); scanf("%d%d", &a[n - 3], &b[n - 3]);
    	add(1); scanf("%d%d", &a[n - 2], &b[n - 2]);
    	add(2); scanf("%d%d", &a[n - 1], &b[n - 1]);
    	add(1); scanf("%d%d", &a[n], &b[n]);
    	ans[1] = f[a[n] - a[n - 1]] - 1;
    	ans[3] = (b[n] - b[n - 1]) - (b[n - 2] - b[n - 3]) - 1;
    	ans[2] = (b[n] - b[n - 1]) / (ans[3] + 1) - 1;
    	ans[4] = (b[n - 1] - b[n - 2]) / (ans[3] + 1) - (ans[1] + 1) - (n > 4);
    	for (int i = n - 3; i >= 2; i--)
    	{
    		int x = n - i;
    		ans[x + 2] = (b[i] - b[i - 1] - ans[x - 2] * ans[x - 1] - ans[x - 1]
    			* (ans[x + 1] + 1)) / (ans[x + 1] + 1) - (i > 2);
    	}
    	printf("! ");
    	for (int i = 1; i <= n; i++) printf("%d ", ans[i]);
    	return puts(""), 0;
    }
    

    E1

    题意

    • 给定 (n)([0,2^m)) 内的数

    • 对于所有的 (0le ile m),求这些数有多少个子集的异或和,二进制下 (1) 的个数为 (i)

    • (1le nle 2 imes10^5)(0le mle 35)

    做法:线性基+枚举((k) 较小)/DP((k) 较大)

    • 由于 E2 比 E1 难太太太多,就分开讲了

    • 显然先求线性基,设这个基由 (k) 个元素组成

    • 原一个子集的异或和可以表示成线性基内一个子集的异或和,再选上线性基外的一部分 (0),也就是线性基内一个子集的贡献为 (2^{n-k})

    • (k) 较小的时候,可以暴力枚举每个基变量是否选上:(O(2^k))

    • (k) 较大的时候,可以高斯消元求出简化阶梯矩阵(若矩阵第 (i) 行第 (i) 列为 (1) 则第 (i) 列的其他元素均为 (0)),然后 DP (f_{i,j,S}) 表示前 (i) 个基变量中选出了 (j) 个,不在基上的位异或和为 (S) 的方案数,统计答案时答案 (ans_{j+popcount(S)}+=f_{m-k,j,S})(O(2^{m-k}k^2))

    • 结合这两种算法可过 E1

    代码

    #include <bits/stdc++.h>
     
    template <class T>
    inline void read(T &res)
    {
    	res = 0; bool bo = 0; char c;
    	while (((c = getchar()) < '0' || c > '9') && c != '-');
    	if (c == '-') bo = 1; else res = c - 48;
    	while ((c = getchar()) >= '0' && c <= '9')
    		res = (res << 3) + (res << 1) + (c - 48);
    	if (bo) res = ~res + 1;
    }
     
    typedef long long ll;
     
    const int N = 2e5 + 5, E = 40, C = 17000, djq = 998244353;
     
    int n, m, orz = 1, cnt1, p1[N], cnt0, p0[N], f[E][E][C], st[E], ans[E];
    ll a[N], b[E];
     
    void ins(ll x)
    {
    	for (int i = m - 1; i >= 0; i--)
    	{
    		if (!((x >> i) & 1)) continue;
    		if (b[i] == -1) return (void) (b[i] = x);
    		else x ^= b[i];
    	}
    	orz = (orz << 1) % djq;
    }
     
    int cc(ll x)
    {
    	int res = 0;
    	while (x) res += x & 1, x >>= 1;
    	return res;
    }
     
    int main()
    {
    	read(n); read(m);
    	for (int i = 0; i < m; i++) b[i] = -1;
    	for (int i = 1; i <= n; i++) read(a[i]), ins(a[i]);
    	for (int i = 0; i < m; i++) if (b[i] != -1)
    		for (int j = i + 1; j < m; j++)
    			if (b[j] != -1 && ((b[j] >> i) & 1)) b[j] ^= b[i];
    	for (int i = 0; i < m; i++)
    		if (b[i] != -1) p1[++cnt1] = i;
    		else p0[++cnt0] = i;
    	if (cnt1 <= 20)
    	{
    		for (int S = 0; S < (1 << cnt1); S++)
    		{
    			ll T = 0;
    			for (int i = 1; i <= cnt1; i++)
    				if ((S >> i - 1) & 1) T ^= b[p1[i]];
    			ans[cc(T)]++;
    		}
    	}
    	else
    	{
    		for (int i = 1; i <= cnt1; i++)
    			for (int j = 1; j <= cnt0; j++)
    				if ((b[p1[i]] >> p0[j]) & 1) st[i] |= 1 << j - 1;
    		f[0][0][0] = 1;
    		for (int i = 0; i < cnt1; i++)
    			for (int j = 0; j <= i; j++)
    				for (int S = 0; S < (1 << cnt0); S++)
    				{
    					f[i + 1][j][S] = (f[i + 1][j][S] + f[i][j][S]) % djq;
    					f[i + 1][j + 1][S ^ st[i + 1]] = (f[i + 1][j + 1][S ^ st[i + 1]]
    						+ f[i][j][S]) % djq;
    				}
    		for (int j = 0; j <= cnt1; j++)
    			for (int S = 0; S < (1 << cnt0); S++)
    			{
    				int x = j + cc(S);
    				ans[x] = (ans[x] + f[cnt1][j][S]) % djq;
    			}
    	}
    	for (int i = 0; i <= m; i++) printf("%d ", 1ll * ans[i] * orz % djq);
    	puts("");
    	return 0;
    }
    

    E2

    题意

    • 同 E1,(0le mle 53)

    做法:FWT+组合数学

    • 妙啊!!!( imes 4)

    • 考虑对于 E1 的第二种算法,把复杂度去掉两个 (k)

    • (A_S) 表示 (S) 是否能被线性基表出,(F^c_S) 表示 (S)(1) 的个数是否为 (c)

    • 我们不难 (neng) 想到 (ans_c) 等于 (FWT(A) imes FWT(F^c)) 所有项之和(这里的 ( imes) 是点乘)除以 (2^m) 后的结果(因为要做 IFWT)

    • 接下来考虑 (FWT(A)) 的性质

    (FWT(A)) 仅由 (0)(2^k) 组成,且第 (S) 位为 (2^k) 当且仅当 (S) 与线性基内所有变量的交集大小都是偶数

    • 证明:

    (S) 与所有基变量的交集大小都是偶数,由于 (S)(Tigoplus U) 的交集大小在奇偶性上等于 (Scap T)(Scap U) 的大小之和,故 (S) 与这个基表出的所有 (2^k) 个数的交集大小都为偶数,由 FWT 的定义可知 (FWT(A)) 的第 (S) 位为 (2^k)
    否则 (S) 与这个基表出的所有 (2^k) 个数的交集大小中奇偶各占一半,由 FWT 的定义可知 (FWT(A)) 的第 (S) 位为 (0)

    另一个性质:

    (FWT(A)) 中为 (2^k) 的位只有 (2^{m-k}) 个,且组成另一个基

    • 证明:

    (FWT(A)) 中第 (S) 位为 (2^k) 的条件转化一下:对于一个不在基上的位 (i),如果让第 (i) 位为 (1),则对于每个满足第 (i) 位为 (1) 的基变量 (j),要让 (S) 的第 (j) 位也异或上 (1)
    这样就有了 (m-k) 个基变量,由于每个基变量的最低位互不相同,故它们可以组成一个基
    但原线性基必须是简化阶梯矩阵,否则在基上的位 (i) 也会对其他在基上的位 (j) 造成影响

    • 于是求出这个大小为 (m-k) 的基后暴力枚举每个变量选或不选,即可得到 (FWT(A)) 中所有为 (2^k) 的位

    • 再考虑 (FWT(F^c)),容易发现 (FWT(F^c)) 的第 (S) 位值只和 (S)(1) 的个数有关

    • 即对于 (S),枚举一个 (1) 的个数为 (c)(T) 贡献 ((-1)^{|Scap T|}),相当于枚举一个 (i) 表示 (S)(T) 表示 (S)(T) 的交集大小

    • 于是 (FWT(F^c)) 包含 (d)(1) 的位值均为:

    • [w_{c,d}=sum_{i=0}^{min(c,d)}(-1)^iinom diinom{m-d}{c-i} ]

    • (FWT(A)) 中含 (c)(1) 的下标有 (q_c)(2^k),则:

    • [ans_c=frac 1{2^{m-k}}sum_{d=0}^mq_dw_{c,d} ]

    • 结合 (k) 较小的暴力枚举,复杂度为 (O(2^{frac m2}+m^3+n))

    代码

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &res)
    {
    	res = 0; bool bo = 0; char c;
    	while (((c = getchar()) < '0' || c > '9') && c != '-');
    	if (c == '-') bo = 1; else res = c - 48;
    	while ((c = getchar()) >= '0' && c <= '9')
    		res = (res << 3) + (res << 1) + (c - 48);
    	if (bo) res = ~res + 1;
    }
    
    typedef long long ll;
    
    const int N = 60, djq = 998244353, i2 = 499122177;
    
    int n, m, orz = 1, cnt1, p[N], cnt0, cnt[N], ans[N], C[N][N];
    ll b[N], a[N];
    
    void ins(ll x)
    {
    	for (int i = m - 1; i >= 0; i--)
    	{
    		if (!((x >> i) & 1)) continue;
    		if (b[i] == -1) return (void) (b[i] = x);
    		else x ^= b[i];
    	}
    	orz = (orz << 1) % djq;
    }
    
    void dfs(int dep, int tar, ll T)
    {
    	if (dep == tar + 1) return (void) (ans[__builtin_popcountll(T)]++);
    	dfs(dep + 1, tar, T); dfs(dep + 1, tar, T ^ a[dep]);
    }
    
    int main()
    {
    	ll x;
    	read(n); read(m);
    	for (int i = 0; i < m; i++) b[i] = -1;
    	for (int i = 1; i <= n; i++) read(x), ins(x);
    	for (int i = 0; i < m; i++) if (b[i] != -1)
    		for (int j = i + 1; j < m; j++)
    			if (b[j] != -1 && ((b[j] >> i) & 1)) b[j] ^= b[i];
    	for (int i = 0; i < m; i++) if (b[i] != -1) a[++cnt1] = b[i];
    	if (cnt1 <= 26) dfs(1, cnt1, 0);
    	else
    	{
    		for (int i = 0; i < m; i++) if (b[i] == -1)
    		{
    			a[++cnt0] = 1ll << i;
    			for (int j = i + 1; j < m; j++) if (b[j] != -1 && ((b[j] >> i) & 1))
    				a[cnt0] |= 1ll << j;
    		}
    		dfs(1, cnt0, 0);
    		for (int i = 0; i <= m; i++) cnt[i] = ans[i], ans[i] = 0, C[i][0] = 1;
    		for (int i = 1; i <= m; i++)
    			for (int j = 1; j <= i; j++)
    				C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % djq;
    		int I = 1;
    		for (int i = 1; i <= cnt0; i++) I = 1ll * I * i2 % djq;
    		for (int i = 0; i <= m; i++)
    			for (int j = 0; j <= m; j++)
    			{
    				int pl = 0;
    				for (int k = 0; k <= j && k <= i; k++)
    				{
    					int delta = 1ll * C[j][k] * C[m - j][i - k] % djq;
    					if (k & 1) pl = (pl - delta + djq) % djq;
    					else pl = (pl + delta) % djq;
    				}
    				ans[i] = (1ll * I * pl % djq * cnt[j] + ans[i]) % djq;
    			}
    	}
    	for (int i = 0; i <= m; i++) printf("%d ", 1ll * ans[i] * orz % djq);
    	return puts(""), 0;
    }
    

    F

    题意

    • 给定 (n) 个节点的树,(m) 条路径和一个 (k)

    • 求有多少对路径的交至少包含 (k) 条边

    • (2le n,mle 1.5 imes10^5)(1le kle n)

    做法:分类讨论+倍增+BIT+线段树

    • 任选一个根,先考虑相交的两条路径 LCA 不同的情况

    • 此时可以把一条路径拆成两条((s_i)(lca_i)(t_i)(lca_i))来看待

    • 下面设拆完之后的路径为 ((up_i,down_i))(up_i) 的深度较小

    • 考虑当 (dep_{up_i}<dep_{up_j}) 时,第 (i) 条和第 (j) 条路径交集至少为 (k) 当且仅当 (up_j) 沿着 (down_j) 的方向走 (k) 步之后还在路径 ((down_i,up_i))

    • 用倍增处理出每个 (up_i) 沿着 (down_i) 的方向走 (k) 步之后到达的点,用 DFS序+差分+BIT 进行单点加和路径查询即可

    • 再考虑 LCA 相同的情况,设这个 LCA 为 (u),这时又分两种:

    • (1)设对于所有的 (i) 都有 (s_i) 的 DFS 序小于 (t_i),则 (s_i)(s_j) 都不为 (u) 且在 (u) 的同一棵子树内,(t_i)(t_j) 也一样

    • (2)反之

    • 先考虑(2),设路径 (i)((x_i,u)) 部分和路径 (j)((x_j,u)) 部分有交集((x_i,x_j) 为路径 (i,j) 的端点之一)

    • 同样地,这相当于 (u) 沿着 (x_i) 向下走 (k) 步和沿着 (x_j) 向下走 (k) 步到达的点相同,也可以拆成两条之后用和之前类似的方法处理

    • 而对于(1),考虑 (v=lca(s_i,s_j)),方案合法当且仅当:

    • (1)(u)(v) 的严格祖先

    • (2)(dep_v-dep_uge k)(v) 朝着 (t_i)(dep_v-dep_u+1) 步之后的节点子树内包含 (t_j)

    • (3)(dep_v-dep_u<k)(v) 朝着 (t_i)(k) 步之后的节点子树内包含 (t_j)

    • 这三个条件中(1)满足且(2)(3)满足一者

    • 如果 (i) 的取值集合和 (j) 的取值集合给定(不交),则可以建立 (n) 棵动态开点线段树,维护每个 LCA 的路径的 (t)

    • 把所有 (j) 插入到第 (lca_j) 棵线段树的 (dfn_{t_j}) 位置之后,对于每个 (i) 查询第 (lca_i) 棵线段树上某个节点的子树和即可

    • 回到原问题,可以 dsu-on-tree:对这棵树每个非叶节点找出一个 preferred child(即设 (cnt_u=sum_i[s_i=u]),preferred child 为 (cnt_u) 的和最大的子树),然后 dfs 的过程中,先递归轻儿子并把线段树上的东西清掉,然后递归重儿子,这时不要把线段树上的东西清掉,把重子树以外的所有路径的 (s) 加入并统计答案

    • 期间可用一个 set 维护当前子树内的所有路径

    • (O(mlog^2m+nlog n))

    • 本题的巧妙之处就在于,使用了从交点处移动 (k) 步的方法,来判断两条路径的交长度是否 (ge k)

    代码

    #include <bits/stdc++.h>
    
    template <class T>
    inline void read(T &res)
    {
    	res = 0; bool bo = 0; char c;
    	while (((c = getchar()) < '0' || c > '9') && c != '-');
    	if (c == '-') bo = 1; else res = c - 48;
    	while ((c = getchar()) >= '0' && c <= '9')
    		res = (res << 3) + (res << 1) + (c - 48);
    	if (bo) res = ~res + 1;
    }
    
    typedef long long ll;
    typedef std::set<int>::iterator it;
    
    const int N = 15e4 + 5, M = N << 1, L = 1e7 + 5, E = 20;
    
    int n, m, k, ecnt, nxt[M], adj[N], go[M], times, dfn[N], dep[N], fa[N][E],
    s[N], t[N], l[N], p[N], A[N], sze[N], cnt[N], son[N], rt[N], ToT, top, stk[M];
    ll ans;
    std::set<int> orz[N];
    std::vector<int> a[N], b[N];
    
    struct node
    {
    	int lc, rc, sum;
    } T[L];
    
    void change(int l, int r, int pos, int v, int &p)
    {
    	if (!p) p = ++ToT; T[p].sum += v;
    	if (l == r) return;
    	int mid = l + r >> 1;
    	if (pos <= mid) change(l, mid, pos, v, T[p].lc);
    	else change(mid + 1, r, pos, v, T[p].rc);
    }
    
    int ask(int l, int r, int s, int e, int p)
    {
    	if (!p || e < l || s > r) return 0;
    	if (s <= l && r <= e) return T[p].sum;
    	int mid = l + r >> 1;
    	return ask(l, mid, s, e, T[p].lc) + ask(mid + 1, r, s, e, T[p].rc);
    }
    
    void change(int x, int v)
    {
    	for (; x <= n; x += x & -x)
    		A[x] += v;
    }
    
    void sub(int u) {change(dfn[u], 1); change(dfn[u] + sze[u], -1);}
    
    int ask(int x)
    {
    	int res = 0;
    	for (; x; x -= x & -x) res += A[x];
    	return res;
    }
    
    inline bool comp(int a, int b)
    {
    	return dep[l[a]] > dep[l[b]] || (dep[l[a]] == dep[l[b]] && l[a] < l[b]);
    }
    
    void add_edge(int u, int v)
    {
    	nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
    	nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
    }
    
    void dfs(int u, int fu)
    {
    	dep[u] = dep[fa[u][0] = fu] + (sze[u] = 1);
    	for (int i = 0; i < 17; i++) fa[u][i + 1] = fa[fa[u][i]][i];
    	dfn[u] = ++times;
    	for (int e = adj[u], v; e; e = nxt[e])
    		if ((v = go[e]) != fu) dfs(v, u), sze[u] += sze[v];
    }
    
    int lca(int u, int v)
    {
    	if (dep[u] < dep[v]) std::swap(u, v);
    	for (int i = 17; i >= 0; i--)
    		if (dep[fa[u][i]] >= dep[v])
    			u = fa[u][i];
    	if (u == v) return u;
    	for (int i = 17; i >= 0; i--)
    		if (fa[u][i] != fa[v][i])
    			u = fa[u][i], v = fa[v][i];
    	return fa[u][0];
    }
    
    int J(int u, int k)
    {
    	for (int i = 17; i >= 0; i--)
    		if ((k >> i) & 1) u = fa[u][i];
    	return u;
    }
    
    void init(int u, int fu)
    {
    	int mx = -1;
    	for (int e = adj[u], v; e; e = nxt[e])
    		if ((v = go[e]) != fu)
    		{
    			init(v, u); cnt[u] += cnt[v];
    			if (cnt[v] > mx) mx = cnt[v], son[u] = v;
    		}
    }
    
    void wtf(int u, int i)
    {
    	if (dfn[l[i]] >= dfn[u] || dfn[u] >= dfn[l[i]] + sze[l[i]]) return;
    	int len = dep[u] + dep[t[i]] - dep[l[i]] * 2;
    	if (len < k || t[i] == l[i]) return;
    	int v = dep[u] - dep[l[i]] >= k ? J(t[i], dep[t[i]] - dep[l[i]] - 1)
    		: J(t[i], len - k);
    	ans += ask(1, n, dfn[v], dfn[v] + sze[v] - 1, rt[l[i]]);
    	if (ask(1, n, dfn[v], dfn[v] + sze[v] - 1, rt[l[i]]));
    }
    
    void DFS(int u, int fu)
    {
    	for (int e = adj[u], v; e; e = nxt[e])
    		if ((v = go[e]) != fu && v != son[u])
    		{
    			DFS(v, u);
    			for (it x = orz[v].begin(); x != orz[v].end(); x++)
    				change(1, n, dfn[t[*x]], -1, rt[l[*x]]);
    		}
    	if (son[u]) DFS(son[u], u);
    	for (it x = orz[u].begin(); x != orz[u].end(); x++)
    		wtf(u, *x), change(1, n, dfn[t[*x]], 1, rt[l[*x]]);
    	if (son[u])
    	{
    		for (int e = adj[u], v; e; e = nxt[e])
    		{
    			if ((v = go[e]) == fu || v == son[u]) continue;
    			for (it x = orz[v].begin(); x != orz[v].end(); x++) wtf(u, *x);
    			for (it x = orz[v].begin(); x != orz[v].end(); x++)
    				change(1, n, dfn[t[*x]], 1, rt[l[*x]]), orz[son[u]].insert(*x);
    		}
    		for (it x = orz[u].begin(); x != orz[u].end(); x++)
    			orz[son[u]].insert(*x);
    		std::swap(orz[u], orz[son[u]]);
    	}
    }
    
    int main()
    {
    	int x, y;
    	read(n); read(m); read(k);
    	for (int i = 1; i < n; i++) read(x), read(y), add_edge(x, y);
    	dfs(1, 0);
    	for (int i = 1; i <= m; i++)
    	{
    		read(s[i]); read(t[i]);
    		if (dfn[s[i]] > dfn[t[i]]) std::swap(s[i], t[i]);
    		l[i] = lca(s[i], t[i]); p[i] = i;
    		orz[s[i]].insert(i); cnt[s[i]]++; a[l[i]].push_back(i);
    	}
    	std::sort(p + 1, p + m + 1, comp);
    	for (int i = 1; i <= m;)
    	{
    		int nxt = i;
    		while (nxt <= m && l[p[i]] == l[p[nxt]]) nxt++;
    		for (int j = i; j < nxt; j++)
    		{
    			int x = p[j], u = s[x], v = t[x], w = l[x];
    			ans += ask(dfn[u]) + ask(dfn[v]) - ask(dfn[w]) * 2;
    		}
    		for (int j = i; j < nxt; j++)
    		{
    			int x = p[j], u = s[x], v = t[x], w = l[x];
    			if (dep[u] - dep[w] >= k) sub(J(u, dep[u] - dep[w] - k));
    			if (dep[v] - dep[w] >= k) sub(J(v, dep[v] - dep[w] - k));
    		}
    		i = nxt;
    	}
    	memset(A, 0, sizeof(A));
    	for (int u = 1; u <= n; u++)
    	{
    		for (int i = 0; i < a[u].size(); i++)
    		{
    			int x = a[u][i];
    			if (dep[s[x]] - dep[u] >= k)
    			{
    				ans += A[y = J(s[x], dep[s[x]] - dep[u] - k)]++; stk[++top] = y;
    				if (t[x] != u) b[J(t[x], dep[t[x]] - dep[u] - 1)].push_back(y);
    			}
    			if (dep[t[x]] - dep[u] >= k)
    			{
    				ans += A[y = J(t[x], dep[t[x]] - dep[u] - k)]++; stk[++top] = y;
    				if (s[x] != u) b[J(s[x], dep[s[x]] - dep[u] - 1)].push_back(y);
    			}
    		}
    		while (top--) A[stk[top + 1]] = 0; top = 0;
    		for (int e = adj[u], v; e; e = nxt[e])
    		{
    			if ((v = go[e]) == fa[u][0]) continue;
    			for (int i = 0; i < b[v].size(); i++)
    				ans -= A[y = b[v][i]]++, stk[++top] = y;
    			while (top--) A[stk[top + 1]] = 0; top = 0;
    		}
    	}
    	init(1, 0); DFS(1, 0);
    	return std::cout << ans << std::endl, 0;
    }
    
  • 相关阅读:
    URL模块之parse方法
    结合GET(),POST()实现一个简单、完整的服务器
    Node.js初探之实现能向前台返回东西的简单服务器
    float和position
    回归博客园·共享onload事件
    百度地图api的用法
    美丽数列
    低位值
    删括号
    牛牛找工作
  • 原文地址:https://www.cnblogs.com/xyz32768/p/12738207.html
Copyright © 2011-2022 走看看