zoukankan      html  css  js  c++  java
  • 牛客练习赛81D 小Q与树

    dsu on tree

    题目链接

    点我跳转

    题目大意

    给定一棵包含 (n) 个节点的树,每个节点有个权值 (a_i)
    (sum_{u=1}^nsum_{v=1}^nmin(a_u,a_v)dis(u,v))

    解题思路

    对于节点 (u)

    • 记权值小于 (a_u) 的节点有 (a_{x1},a_{x2},a_{x3},...,a_{xcnt1})
    • 记权值大于等于 (a_u) 的节点有 (a_{y1},a_{y2},...,a_{ycnt2})

    那么节点 (u) 对答案的贡献为:

    1. (a_u imes(dep_u + dep_{x1} - 2 imes dep_{lca})+a_u imes(dep_u + dep_{x2} - 2 imes dep_{lca})+...)
    2. (a_{y1} imes(dep_u + dep_{y1} - 2 imes dep_{lca})+a_{y2} imes(dep_u + dep_{y2} - 2 imes dep_{lca})+...)

    即:

    1. (a_u imes cnt1 imes (dep_u -2 imes dep_{lca}) + a_u imes(deq_{(x1+...+xcnt1)}))
    2. (a_{(y1+...+ycnt2)} imes (dep_u - 2 imes dep_{lca})+a_{(y1+...+ycnt2)} imes dep_{(y1+..+ycnt2)})

    定义 (rt) 为当前子树的根,那么 (lca = rt)

    开四棵权值树状数组,分别用来维护 (cnt)(dep)(a_i)(a_i imes dep_i)

    然后跑一遍 (dsu~on~tree) 即可

    AC_Code

    #include<bits/stdc++.h>
    #define int long long
    using namespace std;
    template<typename T>void read(T &res)
    {
    	bool flag=false;
    	char ch;
    	while(!isdigit(ch=getchar()))(ch=='-')&&(flag=true);
    	for(res=ch-48; isdigit(ch=getchar()); res=(res<<1)+(res<<3)+ch - 48);
    	flag&&(res=-res);
    }
    template<typename T>void Out(T x)
    {
    	if(x<0)putchar('-'),x=-x;
    	if(x>9)Out(x/10);
    	putchar(x%10+'0');
    }
    const int N = 2e5 + 10 , mod = 998244353;
    int n , ans , a[N] , dep[N] , sz[N] , HH , hson[N] , M;
    struct Edge{
    	int nex , to;
    } edge[N << 1];
    int head[N] , TOT;
    void add_edge(int u , int v)
    {
    	edge[++ TOT].nex = head[u];
    	edge[TOT].to = v;
    	head[u] = TOT;
    }
    struct TR{
    	int tr[N];
    	int lowbit(int x){
    		return x & (-x);
    	}
    	void add(int pos , int val)
    	{
    		while(pos <= M)
    		{
    			tr[pos] = (tr[pos] + val + mod) % mod;
    			pos += lowbit(pos);
    		}
    	}
    	int query(int pos)
    	{
    		int res = 0;
    		while(pos)
    		{
    			res += tr[pos];
    			res %= mod;
    			pos -= lowbit(pos);
    		}
    		return res;
    	}
    	int get_sum(int L , int R){
    		return (query(R) - query(L - 1) + mod) % mod;
    	}
    } tree1 , tree2 , tr1 , tr2;
    vector<int>vec;
    int get_id(int x){
    	return lower_bound(vec.begin() , vec.end() , x) - vec.begin() + 1;
    }
    void dfs(int u , int far)
    {
    	dep[u] = dep[far] + 1 , sz[u] = 1;
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far) continue ;
    		dfs(v , u);
    		sz[u] += sz[v];
    		if(sz[v] > sz[hson[u]]) hson[u] = v;
    	}
    }
    void change(int u , int far , int val)
    {
    	tree1.add(a[u] , dep[u] * val);
    	tree2.add(a[u] , vec[a[u] - 1] * dep[u] * val);
    	tr1.add(a[u] , val);
    	tr2.add(a[u] , val * vec[a[u] - 1]);
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far || v == HH) continue ;
    		change(v , u , val);
    	}
    }
    void calc(int u , int far , int rt)
    {
    	int cnt = tr1.get_sum(a[u] , M);
    	int sum = tree1.get_sum(a[u] , M);
    	int mi = vec[a[u] - 1];
    		ans += mi * dep[u] * cnt + mi * sum;
    		ans -= mi * cnt * 2 * dep[rt];
    		ans = (ans + mod) % mod;
    	sum = tree2.get_sum(1 , a[u] - 1);
    	cnt = tr2.get_sum(1 , a[u] - 1);
    		ans += sum + cnt * dep[u];
    		ans -= cnt * 2 * dep[rt];
    		ans = (ans + mod) % mod;
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far || v == HH) continue ;
    		calc(v , u , rt);
    	}
    }
    void dsu(int u , int far , int op)
    {
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far || v == hson[u]) continue ;
    		dsu(v , u , 0);
    	}
    	if(hson[u]) dsu(hson[u] , u , 1) , HH = hson[u];
    	for(int i = head[u] ; i ; i = edge[i].nex)
    	{
    		int v = edge[i].to;
    		if(v == far || v == HH) continue;
    		calc(v , u , u) , change(v , u , 1);
    	}
    	int cnt = tr1.get_sum(a[u] , M);
    	int sum = tree1.get_sum(a[u] , M);
    	int mi = vec[a[u] - 1];
    		ans += mi * dep[u] * cnt + mi * sum;
    		ans -= mi * cnt * 2 * dep[u];
    		ans = (ans + mod) % mod;
    	sum = tree2.get_sum(1 , a[u] - 1);
    	cnt = tr2.get_sum(1 , a[u] - 1);
    		ans += sum + cnt * dep[u];
    		ans -= cnt * 2 * dep[u];
    		ans = (ans + mod) % mod;
    	tree1.add(a[u] , dep[u]);
    	tree2.add(a[u] , vec[a[u] - 1] * dep[u]);
    	tr1.add(a[u] , 1);
    	tr2.add(a[u] , vec[a[u] - 1]);
    	HH = 0;
    	if(!op) change(u , far , -1);
    }
    signed main()
    {
    	read(n);
    	for(int i = 1 ; i <= n ; i ++) read(a[i]) , vec.push_back(a[i]);
    	for(int i = 1 ; i <  n ; i ++)
    	{
    		int u , v;
    		read(u) , read(v);
    		add_edge(u , v) , add_edge(v , u);
    	}
    	sort(vec.begin() , vec.end());
    	vec.erase(unique(vec.begin() , vec.end()) , vec.end());
    	for(int i = 1 ; i <= n ; i ++) a[i] = get_id(a[i]);
    	M = vec.size();
    	dfs(1 , 0);
    	dsu(1 , 0 , 1);
    	Out(ans * 2 % mod) , puts("");
    	return 0;
    }
    
    凡所不能将我击倒的,都将使我更加强大
  • 相关阅读:
    seata原理
    activemq 启动时出现错误 Address already in use: JVM_Bind
    高并发第五弹:安全发布对象及单例模式
    高并发第三弹:线程安全原子性
    高并发第一弹:准备阶段 了解高并发
    CentOS7安装PostgreSQL9.4
    高并发第二弹:并发概念及内存模型(JMM)
    高并发第四弹:线程安全性可见性有序性
    设计模式模板方法模式
    设计模式建造者模式(图解,使用场景)
  • 原文地址:https://www.cnblogs.com/StarRoadTang/p/14958220.html
Copyright © 2011-2022 走看看