zoukankan      html  css  js  c++  java
  • SPOJ COT2 Count on a tree II 树上莫队算法

    题意:

    给出一棵(n(n leq 4 imes 10^4))个节点的树,每个节点上有个权值,和(m(m leq 10^5))个询问。
    每次询问路径(u o v)上有多少个权值不同的点。

    分析:

    • 树分块

    首先将树分块,每块的大小为(sqrt{n})左右。
    然后将询问离线处理,按照区间上的莫队算法将询问按块排序。
    这里有一道裸的树分块的题目。

    • 树上的路径转移

    定义(S(u,v))表示路径(u o v)上的点集,定义(igoplus)为集合的对称差,类似于异或运算。
    那么有(S(u,v)=S(root,u) igoplus S(root, v) igoplus LCA(u,v)),有一个(LCA)不方便处理。
    再定义一个(T(u,v)=S(root,u) igoplus S(root, v))
    (T(u_1, v_1) igoplus T(v_1, v_2)=S(root, u_1) igoplus S(root, v_1) igoplus S(root, v_1) igoplus S(root, v_2))
    消去中间两项得到:(T(u_1, v_1) igoplus T(v_1, v_2)=S(root, u_1) igoplus S(root, v_2)=T(u_1, v_2))

    从结论可以看出,由(T(u_1,v_1))(T(u_1, v_2))只需要(igoplus)一个(T(v_1, v_2))
    由于对称差(igoplus)运算满足交换律和结合律,所以再(igoplus)一个(T(u_1, u_2))就得到(T(u_2,v_2))

    假设上次查询的路径为(u o v),我们维护点集(T(u,v))的信息:(in(u))(u)是否在集合中,(cnt(x))集合中权值为(x)点的个数,(diff)权值不同的点数。
    查询的话如果(LCA(u,v))的权值没有出现过,答案就是(diff+1),否则就是(diff)

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <cmath>
    using namespace std;
    
    const int maxn = 40000 + 10;
    const int maxq = 100000 + 10;
    
    struct Edge
    {
    	int v, nxt;
    	Edge() {}
    	Edge(int v, int nxt): v(v), nxt(nxt) {}
    };
    
    int ecnt, head[maxn];
    Edge edges[maxn * 2];
    
    void AddEdge(int u, int v) {
    	edges[ecnt] = Edge(v, head[u]);
    	head[u] = ecnt++;
    }
    
    int n, m;
    
    int a[maxn], b[maxn], tot;
    
    int anc[maxn][20], dep[maxn];
    int group[maxn], blocks, sz;
    
    int S[maxn], top;
    
    void dfs(int u) {
    	int cur = top;
    	for(int i = head[u]; ~i; i = edges[i].nxt) {
    		int v = edges[i].v;
    		if(v == anc[u][0]) continue;
    		anc[v][0] = u;
    		dep[v] = dep[u] + 1;
    		dfs(v);
    		if(top - cur >= sz) {
    			blocks++;
    			while(top != cur) group[S[top--]] = blocks;
    		}
    	}
    	S[++top] = u;
    }
    
    struct Query
    {
    	int u, v, id;
    	bool operator < (const Query& t) const {
    		return group[u] < group[t.u] || (group[u] == group[t.u] && group[v] < group[t.v]);
    	}
    }q[maxq];
    
    void preprocess() {
    	for(int j = 1; (1 << j) < n; j++)
    		for(int i = 1; i <= n; i++) if(anc[i][j-1])
    			anc[i][j] = anc[anc[i][j-1]][j-1];
    }
    
    int LCA(int u, int v) {
    	if(dep[u] < dep[v]) swap(u, v);
    	int log;
    	for(log = 0; (1 << log) < dep[u]; log++);
    	for(int i = log; i >= 0; i--)
    		if(dep[u] - (1<<i) >= dep[v]) u = anc[u][i];
    	if(u == v) return u;
    	for(int i = log; i >= 0; i--)
    		if(anc[u][i] && anc[u][i] != anc[v][i])
    			u = anc[u][i], v = anc[v][i];
    	return anc[u][0];
    }
    
    int cnt[maxn], dif, in[maxn];
    
    void xorvertex(int u) {
    	if(in[u]) { cnt[a[u]]--; if(!cnt[a[u]]) dif--; }
    	else { cnt[a[u]]++; if(cnt[a[u]] == 1) dif++; }
    	in[u] ^= 1;
    }
    
    void xorpath(int u, int v) {
    	if(dep[u] < dep[v]) swap(u, v);
    	while(dep[u] > dep[v]) { xorvertex(u); u = anc[u][0]; }
    	while(u != v) {
    		xorvertex(u); xorvertex(v);
    		u = anc[u][0]; v = anc[v][0];
    	}
    }
    
    int ans[maxq];
    
    int main()
    {
    	scanf("%d%d", &n, &m);
    	for(int i = 1; i <= n; i++) {
    		scanf("%d", a + i);
    		b[i] = a[i];
    	}
    	sort(b + 1, b + 1 + n);
    	tot = unique(b + 1, b + 1 + n) - b - 1;
    	for(int i = 1; i <= n; i++)
    		a[i] = lower_bound(b + 1, b + 1 + tot, a[i]) - b;
    
    	ecnt = 0;
    	memset(head, -1, sizeof(head));
    	for(int i = 1; i < n; i++) {
    		int u, v; scanf("%d%d", &u, &v);
    		AddEdge(u, v); AddEdge(v, u);
    	}
    
    	sz = (int)sqrt(n);
    	dfs(1);
    	while(top) group[S[top--]] = blocks;
            preprocess();
    
    	for(int i = 1; i <= m; i++) {
    		scanf("%d%d", &q[i].u, &q[i].v);
    		q[i].id = i;
    		if(q[i].u > q[i].v) swap(q[i].u, q[i].v);
    	}
    
    	sort(q + 1, q + 1 + m);
    
    	int u = 1, v = 1;
    	for(int i = 1; i <= m; i++) {
    		xorpath(u, q[i].u);
    		xorpath(v, q[i].v);
    		u = q[i].u, v = q[i].v;
    		int lca = LCA(u, v);
    		ans[q[i].id] = dif;
    		if(!cnt[a[lca]]) ans[q[i].id]++;
    	}
    
    	for(int i = 1; i <= m; i++) printf("%d
    ", ans[i]);
    
    	return 0;
    }
    
  • 相关阅读:
    python格式化输出之format用法
    Mybatis插入数据返回主键
    DBC 和 Mybatis连接mysql数据库的时候,设置字符集编码
    工具列表
    Idea的Git如何回退到上一个版本
    mybatis-plus id主键生成的坑
    JAVA 线上故障排查完整套路,从 CPU、磁盘、内存、网络、GC 一条龙!
    DDD-快速hold住业务的利器
    深入理解ThreadLocal的原理和内存泄漏问题
    VUE开发--环境配置
  • 原文地址:https://www.cnblogs.com/AOQNRMGYXLMV/p/5289778.html
Copyright © 2011-2022 走看看