zoukankan      html  css  js  c++  java
  • 【ybt金牌导航5-4-4】【luogu P4842】城市旅行

    城市旅行

    题目链接:ybt金牌导航5-4-4 / luogu P4842

    题目大意

    给你一棵树,要你维护一些操作:
    删除某条边(如果两点间不联通就不管)
    添加某条边(如果两点间已联通就不管)
    给某条路径上的点点权加一个值(如果两点不连通就不管)
    询问某条路径上任选两个点,这两个点之间路径的权值和的期望。(如果两点不连通就输出 -1)

    思路

    看到加边删边找路径,自然想到 LCT。
    然后我们考虑如何维护输出的值。

    那容易想到,我们可以补考虑选的点,而是考虑一个点,有多少个路径会经过它。
    那容易想到对于长度为 (sz) 的路径,对于第 (i) 个点,有 (i imes(sz-i+1)) 个路径经过了它。(左右各选一个)
    那它的贡献就是 (i imes(sz-i+1) imes a_i)

    那总贡献就是:(dfrac{sumlimits_{i=1}^{sz}i imes(sz-i+1) imes a_i}{C_{sz+1}^2})
    (下面就是 (C_{sz+1}^2) 因为选的两个点可以是同一个点)

    那我们要维护的就是它了,分母很好搞,直接每次算就行,问题是分子。
    那我们考虑 DP,已经求出了左右子树,要怎么搞到它。
    我们设左子树的大小是 (b_0),然后序列是 (b_1,b_2,...,b_{b_0}),右子树大小是 (c_0),序列是 (c_1,c_2,...,c_{c_0})

    那对于左子树里面的第 (i) 个点,它在左子树里面的贡献就是 (i imes(b_0-i+1) imes b_i),它在这里的贡献就是 (i imes(b_0+c_0+1-i+1) imes b_i)。作差,就是 (i imes b_i imes (c_0+1))
    那左子树的贡献就是它原本的贡献加上 ((c_0+1) imessumlimits_{i=1}^{b_0}i imes b_i)
    那我们发现右边的部分((sumlimits_{i=1}^{b_0}i imes b_i))我们也可以 DP,我是用 (lsum) 数组记录,这里就不讲了,不会的自己看代码。

    那接着右子树用同样的方法:
    原本:(i imes(c_0-i+1) imes c_i)
    现在:((b_0+1+i) imes(b_0+c_0+1-(b_0+1+i)+1) imes c_i)
    (=(b_0+1+i) imes(c_0-i+1) imes c_i)
    差:((b_0+1) imes(c_0-i+1) imes c_i)
    那左子树的贡献就是它原本的贡献加上 ((b_0+1) imessumlimits_{i=1}^{b_0}(c_0-i+1) imes c_i)
    然后右边部分((sumlimits_{i=1}^{b_0}(c_0-i+1) imes c_i))继续 DP,我是用 (rsum) 数组记录。

    接着就是新的点,那这个其实容易,就直接暴力算:(a imes(b_0+1) imes(c_0+1))。(记得加一)

    那查询我们就搞定了,接着,就是修改了。(加边删边就是普通 LCT,不用搞)
    那我们也是懒标记,那每次要怎么改呢?
    首先权值就普通的加,权值和就加上它乘大小。
    接着是 (lsum,rsum),容易看到你每个数每加 (x),值就会多 (x+2x+3x+4x+...),那就是 (x imes (1+sz) imes sz / 2)
    那接着就是 (ans),即期望的分子,那我们可以列出式子:(ans+=sumlimits_{i=1}^{sz}i imes(sz-i+1) imes d)

    然后我们由化简可以得到 (ans+=dfrac{sz(sz+1)(sz+2)}{6} imes d)
    然后就可以搞啦!

    化简过程

    知道的可以不看。
    要搞的东西:(sumlimits_{i=1}^{sz}i imes(sz-i+1)=dfrac{sz(sz+1)(sz+2)}{6})
    首先考虑让其中一项固定:
    (sumlimits_{i=1}^{sz}i imes(sz-i+1)=sumlimits_{i=1}^{sz}i imes sz-sumlimits_{i=1}^{sz}i imes(i-1))
    然后右边部分考虑去括号:(sumlimits_{i=1}^{sz}i imes sz-sumlimits_{i=1}^{sz}(i^2-i))
    分别拿出来:(sz imessumlimits_{i=1}^{sz}i-sumlimits_{i=1}^{sz}i^2+sumlimits_{i=1}^{sz}i)
    然后都可以去掉 (sum)(sz imesfrac{(sz+1) imes sz}{2}-frac{sz(sz+1)(2 imes sz+1)}{6}+frac{(sz+1) imes sz}{2})
    合并一下:(frac{3(sz+1)(sz+1)sz}{6}-frac{sz(sz+1)(2sz+1)}{6})
    (frac{(3sz+3)(sz+1)sz}{6}-frac{(2sz+1)(sz+1)sz}{6})
    (frac{(sz+2)(sz+1)sz}{6})
    然后就好啦!

    可能有人(指我自己)会不知道为什么 (sumlimits_{i=1}^{sz}i^2=dfrac{sz(sz+1)(2sz+1)}{6})
    然后这里也讲讲,这个是用立方差来搞的。
    (x^3-(x-1)^3=x^3-(x^3-3x^2+3x-1)=3x^2-3x+1)
    然后根据这个,我们把 ((n^3-(n-1)^3)+((n-1)^3-(n-2)^3)+...+(2^3-1^3)) 每个都转。
    那互相消掉,就是 (n^3-1^3=(3n^2-3n+1)+(3(n-1)^2-3(n-1)+1)+...+(3 imes2^2-3 imes2+1))
    拆开:(n^3-1=3n^2+3(n-1)^2+...+3 imes2^2-(3n+3(n-1)+...+3 imes2+(n-1)))
    然后继续搞:(n^3-1=3(n^2+(n-1)^2+...+2^2)-3(n+(n-1)+...+2)+(n-1))
    移项:(3(n^2+(n-1)^2+...+2^2+1^2)=n^3-1-(n-1)+frac{3(n+2)(n-1)}{2}+3 imes1^2)
    (3(n^2+(n-1)^2+...+2^2+1^2)=n^3-n+3+frac{3(n+2)(n-1)}{2})
    ((n^2+(n-1)^2+...+2^2+1^2)=frac{2n^3-2n+6+3(n+2)(n-1)}{6})
    (=frac{2n^3-2n+6+3(n^2+n-2)}{6}=frac{2n^3+3n^2+n}{6}=frac{n(2n^2+3n+1)}{6}=frac{n(n+1)(2n+1)}{6})
    然后就有了。

    代码

    #include<cstdio>
    #include<algorithm>
    #define ll long long
    
    using namespace std;
    
    int n, m, sz[50001], d;
    int l[50001], r[50001], fa[50001];
    ll ans[50001], val[50001], lz[50001];
    ll lsum[50001], rsum[50001], sum[50001];
    bool lzs[50001];
    int op, x, y;
    
    //LCT
    bool nrt(int x) {
    	return l[fa[x]] == x || r[fa[x]] == x;
    }
    
    bool ls(int x) {
    	return l[fa[x]] == x;
    }
    
    void up(int x) {//把推公式推出来的放上去
    	sz[x] = sz[l[x]] + sz[r[x]] + 1;
    	sum[x] = sum[l[x]] + sum[r[x]] + val[x];
    	
    	//DP 维护 lsum rsum
    	lsum[x] = lsum[l[x]] + val[x] * (sz[l[x]] + 1) + (lsum[r[x]] + sum[r[x]] * (sz[l[x]] + 1));
    	rsum[x] = rsum[r[x]] + val[x] * (sz[r[x]] + 1) + (rsum[l[x]] + sum[l[x]] * (sz[r[x]] + 1));
    	ans[x] = ans[l[x]] + ans[r[x]] + (sz[r[x]] + 1) * lsum[l[x]] + (sz[l[x]] + 1) * rsum[r[x]] + val[x] * (sz[l[x]] + 1) * (sz[r[x]] + 1); 
    }
    
    void downa(int x, ll Val) {
    	val[x] += Val;
    	lz[x] += Val;
    	sum[x] += Val * sz[x];
    	
    	lsum[x] += Val * (1 + sz[x]) * sz[x] / 2;
    	rsum[x] += Val * (sz[x] + 1) * sz[x] / 2;
    	ans[x] += Val * sz[x] * (sz[x] + 1) * (sz[x] + 2) / 6;
    }
    
    void downs(int x) {
    	swap(l[x], r[x]);
    	swap(lsum[x], rsum[x]);//记得这个也要 swap
    	lzs[x] ^= 1;
    } 
    
    void down(int x) {
    	if (lzs[x]) {
    		if (l[x]) downs(l[x]);
    		if (r[x]) downs(r[x]);
    		lzs[x] = 0;
    	}
    	if (lz[x]) {
    		if (l[x]) downa(l[x], lz[x]);
    		if (r[x]) downa(r[x], lz[x]);
    		lz[x] = 0;
    	}
    }
    
    void down_line(int x) {
    	if (nrt(x)) down_line(fa[x]);
    	down(x);
    }
    
    void rotate(int x) {
    	int y = fa[x];
    	int z = fa[y];
    	int b = (ls(x) ? r[x] : l[x]);
    	if (z && nrt(y)) (ls(y) ? l[z] : r[z]) = x;
    	if (ls(x)) r[x] = y, l[y] = b;
    		else l[x] = y, r[y] = b;
    	fa[x] = z;
    	fa[y] = x;
    	if (b) fa[b] = y;
    	up(y);
    }
    
    void Splay(int x) {
    	down_line(x);
    	while (nrt(x)) {
    		if (nrt(fa[x])) {
    			if (ls(x) == ls(fa[x])) rotate(fa[x]);
    				else rotate(x);
    		}
    		rotate(x);
    	}
    	up(x);
    }
    
    void access(int x) {
    	int lst = 0;
    	for (; x; x = fa[x]) {
    		Splay(x);
    		
    		r[x] = lst;
    		up(x);
    		lst = x;
    	}
    }
    
    void make_root(int x) {
    	access(x);
    	Splay(x);
    	downs(x);
    }
    
    int find_root(int x) {
    	access(x);
    	Splay(x);
    	while (l[x]) {
    		down(x);
    		x = l[x];
    	}
    	Splay(x);
    	return x;
    }
    
    int split(int x, int y) {
    	make_root(x);
    	if (find_root(y) != x) return -1; 
    	access(y);
    	Splay(y);
    	return y;
    }
    
    void cut(int x, int y) {//连和断的时候都要判断连通
    	make_root(x);
    	if (find_root(y) != x) return ;
    	access(y);
    	Splay(y);
    	l[y] = 0;
    	fa[x] = 0;
    }
    
    void link(int x, int y) {
    	make_root(x);
    	if (find_root(y) != x)
    		fa[x] = y;
    }
    
    ll gcd(ll x, ll y) {
    	if (!y) return x;
    	return gcd(y, x % y);
    }
    
    void write(ll x, ll y) {
    	ll GCD = gcd(x, y);
    	x /= GCD; y /= GCD;
    	printf("%lld/%lld
    ", x, y);
    }
    
    int main() {
    	scanf("%d %d", &n, &m);
    	for (int i = 1; i <= n; i++) scanf("%d", &val[i]), sz[i] = 1, sum[i] = lsum[i] = rsum[i] = ans[i] = val[i];
    	for (int i = 1; i < n; i++) {
    		scanf("%d %d", &x, &y);
    		link(x, y);
    	}
    	
    	while (m--) {
    		scanf("%d %d %d", &op, &x, &y);
    		
    		if (op == 1) {
    			cut(x, y);
    			continue;
    		}
    		if (op == 2) {
    			link(x, y);
    			continue;
    		}
    		if (op == 3) {
    			scanf("%d", &d);
    			if (find_root(x) != find_root(y)) continue;//记得操作前要判断是否连通
    			x = split(x, y);
    			downa(x, d);
    			continue;
    		}
    		if (op == 4) {
    			if (find_root(x) != find_root(y)) {printf("-1
    ");continue;}
    			x = split(x, y);
    			write(ans[x], 1ll * sz[x] * (sz[x] + 1) / 2);
    			continue;
    		}
    	}
    	
    	return 0;
    }
    
  • 相关阅读:
    开始几天的基本学习
    从这个博客开始我的机器学习深度学习之路
    剑指Offer:面试题3——二维数组中的查找(java实现)
    HIVE配置文件
    C++ 之旅:前言
    leetcode 349:两个数组的交集I
    python学习(三):matplotlib学习
    python学习(二):python基本语法
    Android环境搭建
    LeetCode:237
  • 原文地址:https://www.cnblogs.com/Sakura-TJH/p/YBT_JPDH_5-4-4.html
Copyright © 2011-2022 走看看