zoukankan      html  css  js  c++  java
  • 2020-2021 “Orz Panda” Cup Programming Contest G题(树形结构)

    题目传送门

    题目大意:给点一颗包含 (n)个节点的无根树,有 (m)次询问,每次询问给出两个点 (u)(v),要求计算

    [sum_{r=1}^{n}d_{r}(u,v) ]

    (d_{r}(u,v))是以 (r)为根的树上 (u)(v)的“美丽路径”,它的定义为:

    [d_{r}(u,v)=dis(u,lca_{r}(u,v)) imes dis(v,lca_{r}(u,v)) ]

    其中 (lca_{r}(u,v))是以节点 (r)为根的树中,点 (u)和点 (v)的最近公共祖先。(dis(u,v))等于 (u)(v)之间最短路径的边数。

    输入:第一行输入 (n,m),接下来 (n-1)行给出连边情况,接下来 (m)行代表 (m)组询问。

    输出:对于每个询问输出答案对998244353取模

    数据范围(1 leq n,m leq 1e5)

    分析:令节点 (1)为根,简化问题。考虑要算的东西,发现它只与 (u)(v)的路径上的节点以及这些节点的“分支节点”有关。不明白的话可以画图具体算一下。考虑点 (u)(lca)上的节点 (u_{1},u_{2}...u_{k}),假设 (u_{p})(u_{k})的“分支节点”,那么无论是以 (u_{k})为根还是以 (u_{p})为根, (lca(u,v))都等于 (u_{k}),也就是说可以把 (u_{k})的“分支节点”对答案的贡献累加到 (u_{k})上。假设原本 (u_{k})对答案的贡献为 (w),那么现在就等于 ((num+1) cdot w)(num)是“分支节点”的个数,设 (siz[x])是以 (1)为根的树中以 (x)为根的子树大小,那么 (num=siz[u_{k}]-siz[u_{k-1}]),设 (u,v)之间的距离为 (dis)(dis=dep[u]+dep[v]-2 imes dep[lca])。那么 (u)(lca)上的节点 (u_{1},u_{2}...u_{k})对答案的贡献就等于

    [sum_{r=1}^{k}(siz[u_{k}]-siz[u_{k-1}]) imes (dep[u]-dep[u_{k}]) imes (dis-(dep[u]-dep[u_{k}])) ]

    把它拆成 (8)项,分别计算就好,求下前缀和就可以 (O(1))计算。对于 (v)(lca)的那部分贡献同理计算。另外 (lca)对答案的贡献需要另算。

    #include<cstdio>
    typedef long long ll;
    const int N = 1e5 + 5;
    const int mod = 998244353;
    
    int n, m, cnt, son_u, son_v;
    int head[N], dep[N], son[N], fa[N], top[N];
    ll d_siz[N], d2_siz[N], fa_d_siz[N], fa_d2_siz[N], siz[N];
    // son_u表示u到lca路径上离lca最近的点,son_v同理
    // d_siz[x] = dep[x] * siz[x]
    // d2_siz[x] = dep[x] * dep[x] * siz[x]
    // fa_d_siz[x] = dep[fa[x]] * siz[x]
    // da_d2_siz[x] = dep[fa[x]] * dep[fa[x]] * siz[x]
    
    struct Edge{
    	int nex, to;
    }e[N << 1];
    
    inline ll max(ll a, ll b) { return a > b ? a : b; }
    inline void add(int a, int b) { e[++cnt] = {head[a], b};  head[a] = cnt; }
    
    void dfs1(int u, int f){
    	dep[u] = dep[f] + 1, fa[u] = f, siz[u] = 1;
    	for(int i = head[u]; i; i = e[i].nex){
    		int to = e[i].to;
    		if(to == f)  continue;
    		dfs1(to, u);
    		if(siz[to] > siz[son[u]])  son[u] = to;
    		siz[u] += siz[to];
    	}
    }
    
    void dfs2(int u, int ttop){
    	top[u] = ttop;
    	if(son[u])  dfs2(son[u], ttop);
    	for(int i = head[u]; i; i = e[i].nex){
    		int to = e[i].to;
    		if(to == fa[u] || to == son[u])  continue;
    		dfs2(to, to);
    	}
    }
    
    void dfs3(int u, int f){
    	d_siz[u] = (1LL * dep[u] * siz[u] + d_siz[f]) % mod;
    	d2_siz[u] = (1LL * dep[u] * dep[u] % mod * siz[u] + d2_siz[f]) % mod;
    	fa_d_siz[u] = (1LL * dep[fa[u]] * siz[u] + fa_d_siz[f]) % mod;
    	fa_d2_siz[u] = 	(1LL * dep[fa[u]] * dep[fa[u]] % mod * siz[u] + fa_d2_siz[f]) % mod;
    	for(int i = head[u]; i; i = e[i].nex){
    		int to = e[i].to;
    		if(to == f)  continue;
    		dfs3(to, u);
    	}
    }
    
    // 找lca和son_u,son_v
    int get_lca(int u, int v){
    	while(top[u] != top[v]){
    		if(dep[top[u]] > dep[top[v]])  son_u = top[u], u = fa[top[u]];
    		else  son_v = top[v], v = fa[top[v]];
    	}
    	if(dep[u] > dep[v])  son_u = son[v];
    	else  son_v = son[u];
    	return dep[u] > dep[v] ? v : u;
    }
    
    ll cal(int u, int v, ll *p){
    	// u或v等于0说明son_u不存在,返回0
    	return (dep[u] < dep[v] || v == 0 || u == 0) ? 0 : p[u] - p[v];
    }
    
    int main(){
    	scanf("%d%d", &n, &m);
    	for(int i = 1, u, v; i < n; ++i){
    		scanf("%d%d", &u, &v);
    		add(u, v),  add(v, u);
    	}
    	dfs1(1, 0);
    	dfs2(1, 1);
    	dfs3(1, 0);
    	for(int i = 1, u, v; i <= m; ++i){
    		scanf("%d%d", &u, &v);
    		ll lca = get_lca(u, v),  dis = dep[u] + dep[v] - (dep[lca] << 1);
    		if(u == lca)  son_u = 0;
    		if(v == lca)  son_v = 0;
    		ll ans = 1LL * (n - siz[son_u] - siz[son_v]) * (dep[u] - dep[lca]) % mod * (dep[v] - dep[lca]) % mod;  // lca的贡献
    		ans -= ((dis - dep[u]) * cal(fa[u], lca, d_siz) + (dis - dep[v]) * cal(fa[v], lca, d_siz)) % mod;
    		ans += ((dis - dep[u]) * dep[u] % mod * max(0, siz[son_u] - siz[u]) + (dis - dep[v]) * dep[v] % mod * max(0, siz[son_v] - siz[v])) % mod;
    		ans += ((dis - dep[u]) * cal(u, son_u, fa_d_siz) + (dis - dep[v]) * cal(v, son_v, fa_d_siz)) % mod;
    		ans -= cal(fa[u], lca, d2_siz) + cal(fa[v], lca, d2_siz);
    		ans += cal(u, son_u, fa_d2_siz) + cal(v, son_v, fa_d2_siz);
    		ans += (dep[u] * cal(fa[u], lca, d_siz) + dep[v] * cal(fa[v], lca, d_siz)) % mod;
    		ans -= (dep[u] * cal(u, son_u, fa_d_siz) + dep[v] * cal(v, son_v, fa_d_siz)) % mod;
    		printf("%lld
    ", (ans % mod + mod) % mod);
    	}
    	return 0;
    }
    
    你只有十分努力,才能看上去毫不费力。
  • 相关阅读:
    Building fresh packages卡很久
    后端阿里代码扫描
    npm 使用淘宝镜像
    git镜像
    mysql安装8.0.18
    idea2019.2.2版本破解
    JDK下载很慢
    解决GitHub下载速度慢下载失败的问题
    Hashtable多线程遍历问题
    2-18 求组合数 注:代码有问题找不出哪儿错了
  • 原文地址:https://www.cnblogs.com/214txdy/p/14075078.html
Copyright © 2011-2022 走看看