zoukankan      html  css  js  c++  java
  • 树上操作

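    题目描述
    给定一棵以节点 1 为根、共 n 个节点的带点权树,依次处理 m 个操作:1 x a —— 节点 x 的点权加 a;2 x a —— 以 x 为根的子树中所有节点的点权加 a;3 x —— 询问从根到节点 x 的路径上所有节点的点权和。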

    思路

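    用树链剖分把树映射到线段树上。由于线段树下标(dfn)由 dfs 序产生,以 x 为根的子树在 dfs 序中是一段连续区间,所以子树修改对应的线段树区间为 [dfn[x], dfn[x] + size[x] - 1];单点修改对应区间 [dfn[x], dfn[x]];查询根到 x 的路径和时,沿重链向上跳,每条重链在线段树上同样是一段连续区间。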

    代码

    #include <cstdio>
    #include <cstring>
    #define lc k<<1
    #define rc k<<1|1
     
    // Capacity of all per-node arrays (node ids are used directly as indices).
    const int MAX = 100010;

    // Node count and operation count read from input.
    int n, m;
    // Digit scratch space (oa) and its fill count (ot) for fast integer output.
    int ot;
    int oa[100];

    // Adjacency list: head[x] is x's first edge id, nt[e] the next edge id,
    // ver[e] the edge's target, ht the number of edges stored so far.
    int head[MAX];
    int ver[MAX << 1];
    int nt[MAX << 1];
    int ht;

    // Node weights, indexed by node id.
    int wt[MAX];

    // Filled by dfs1: parent, depth, subtree size, heavy child.
    int fa[MAX], dep[MAX], size[MAX], son[MAX];

    // Filled by dfs2: chain top, dfs position, node at a dfs position, position counter.
    int top[MAX], dfn[MAX], tr[MAX];
    int dt;

    // Segment tree over dfs positions: range sums and pending add tags;
    // ans accumulates the result of query()/ask().
    long long sum[MAX << 2];
    long long add[MAX << 2];
    long long ans;

    // Label printed by showArray() before dumping an array.
    char showStr[100];
    
    // Fast reader for one signed decimal integer from stdin.
    // Non-digit characters before the number are skipped; a '-' seen while
    // skipping makes the result negative. Returns 0 if input ends before any
    // digit is found.
    inline int read() {
    	int s = 0, f = 1;
    	// int (not char) so that EOF can be told apart from a real byte.
    	int ch = getchar();
    	while (ch < '0' || ch > '9') {
    		// Stop on end of input instead of spinning forever on getchar() == EOF.
    		if (ch == EOF) return 0;
    		if (ch == '-') f = -1;
    		ch = getchar();
    	}
    	while (ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
    	return s * f;
    }
    
    // Fast writer: prints x in decimal to stdout (no trailing separator).
    // The digits are produced from the unsigned magnitude so that
    // LLONG_MIN is handled without the signed-overflow UB of `x = -x`.
    inline void write(long long x) {
    	unsigned long long mag = (unsigned long long)x;
    	if (x < 0) {
    		putchar('-');
    		mag = 0ULL - mag;  // well-defined modular negation
    	}
    	char digits[20];  // 2^64 - 1 has 20 decimal digits
    	int len = 0;
    	do {  // do-while so that x == 0 still prints "0"
    		digits[len++] = (char)('0' + mag % 10);
    		mag /= 10;
    	} while (mag);
    	while (len) putchar(digits[--len]);
    }
    
    // Store the directed edge x -> y and link it in front of x's edge list.
    // Edge ids start at 1; an id of 0 ends a list.
    void add_edge(int x, int y) {
    	++ht;
    	ver[ht] = y;
    	nt[ht] = head[x];
    	head[x] = ht;
    }
    
    // First heavy-light DFS from x, whose parent is u.
    // Records fa (parent), dep (depth = parent's depth + 1), size (subtree
    // node count) and son (the child with the strictly largest subtree seen
    // first; stays 0 when x has no children). Recursion depth equals tree height.
    void dfs1(int x, int u) {
    	fa[x] = u;
    	dep[x] = dep[u] + 1;
    	size[x] = 1;
    	int e = head[x];
    	while (e) {
    		int v = ver[e];
    		if (v != u) {
    			dfs1(v, x);
    			size[x] += size[v];
    			if (size[son[x]] < size[v]) son[x] = v;
    		}
    		e = nt[e];
    	}
    }
    
    // Second heavy-light DFS: gives x the next dfs position (dfn, with tr as
    // the inverse map) and sets its chain top to u. The heavy child is visited
    // first and inherits u, so each heavy chain gets consecutive positions;
    // every other not-yet-numbered neighbour starts a new chain topped by itself.
    void dfs2(int x, int u) {
    	top[x] = u;
    	++dt;
    	dfn[x] = dt;
    	tr[dt] = x;
    	if (son[x] != 0) dfs2(son[x], u);
    	int e = head[x];
    	while (e != 0) {
    		int v = ver[e];
    		if (dfn[v] == 0) dfs2(v, v);
    		e = nt[e];
    	}
    }
    
    // Build segment-tree node k covering dfs positions [l, r]. A leaf at
    // position l holds the weight of node tr[l]; an inner node holds the sum
    // of its two children.
    void build(int k, int l, int r) {
    	if (l != r) {
    		int mid = (l + r) >> 1;
    		build(lc, l, mid);
    		build(rc, mid + 1, r);
    		sum[k] = sum[lc] + sum[rc];
    	} else {
    		sum[k] = wt[tr[l]];
    	}
    }
    
    // Move node k's pending add tag to its children, which cover [l, mid] and
    // [mid + 1, r]: each child's sum grows by tag * its length and the child
    // accumulates the tag. Node k's tag is then cleared.
    void pushdown(int k, int l, int r, int mid) {
    	long long tag = add[k];
    	if (tag != 0) {
    		add[lc] += tag;
    		add[rc] += tag;
    		sum[lc] += tag * (mid - l + 1);
    		sum[rc] += tag * (r - mid);
    		add[k] = 0;
    	}
    }
    
    // Add z to every dfs position in [x, y], starting at node k over [l, r].
    // Nodes fully inside [x, y] are updated lazily (sum adjusted, tag stored);
    // partially covered nodes push their tag down, recurse, then re-sum.
    void change(int k, int l, int r, int x, int y, int z) {
    	if (l >= x && y >= r) {
    		add[k] += z;
    		sum[k] += (long long)z * (r - l + 1);
    		return;
    	}
    	int mid = (l + r) >> 1;
    	pushdown(k, l, r, mid);
    	if (mid >= x) change(lc, l, mid, x, y, z);
    	if (mid < y) change(rc, mid + 1, r, x, y, z);
    	sum[k] = sum[lc] + sum[rc];
    }
    
    // Exchange the values of two ints in place.
    void swap(int &x, int &y) {
    	int old_x = x;
    	x = y;
    	y = old_x;
    }
    
    // Add the sum of dfs positions [x, y] to the global ans, starting at
    // segment-tree node k over [l, r]. Pending tags are pushed down before
    // descending, so children's sums are current when they are read.
    // (The former debug printf comment was split across two lines, leaving a
    // stray string literal outside the comment; it has been removed.)
    void query(int k, int l, int r, int x, int y) {
    	if (x <= l && r <= y) { ans += sum[k]; return; }
    	int mid = l + r >> 1;
    	pushdown(k, l, r, mid);
    	if (x <= mid) query(lc, l, mid, x, y);
    	if (y > mid) query(rc, mid + 1, r, x, y);
    }
    
    void ask(int x, int y) {
    	ans = 0LL;
    	int fx = top[x], fy = top[y];
    	while (fx != fy) {
    		if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
    		query(1, 1, n, dfn[fx], dfn[x]);
    		x = fa[fx], fx = top[x];
    	}
    	if (dep[x] > dep[y]) swap(x, y);
    	query(1, 1, n, dfn[x], dfn[y]);
    }
    
    // Debug helper: print the label held in showStr on its own line, then
    // arr[1..n] on one line, each formatted as "%2d ".
    void showArray(int * arr) {
    	puts(showStr);
    	int i = 1;
    	while (i <= n) {
    		printf("%2d ", arr[i]);
    		++i;
    	}
    	puts("");
    }
    void show() {
    	printf("n:%d m:%d
    ", n, m);
    	for (int i = 1; i <= n; ++i) {
    		printf("%d:", i);
    		for (int j = head[i]; j; j = nt[j]) {
    			printf("%d ", ver[j]);
    		}
    		puts("");
    	}
    	strcpy(showStr, "wt  :"), showArray(wt);
    	strcpy(showStr, "fa  :"), showArray(fa);
    	strcpy(showStr, "size:"), showArray(size);
    	strcpy(showStr, "dep :"), showArray(dep);
    	strcpy(showStr, "son :"), showArray(son);
    	strcpy(showStr, "dfn :"), showArray(dfn);
    	strcpy(showStr, "tr  :"), showArray(tr);
    	strcpy(showStr, "top :"), showArray(top);
    }
    
    int main() {
    	n = read(), m = read();
    	for (int i = 1; i <= n; ++i) wt[i] = read();
    	for (int i = 1, a, b; i < n; ++i) {
    		a = read(), b = read(), 
    		add_edge(a, b), add_edge(b, a);
    	}
    	
    	dfs1(1, 0);
    	dfs2(1, 1);
    	// show();
    	build(1, 1, n);
    	for (int i = 1, j, a, b; i <= m; ++i) {
    		j = read();
    		// printf("%d ", j);
    		switch(j) {
    			case 1: 
    				a = read(), b = read();
    				// printf("%d %d
    ", a, b);
    				change(1, 1, n, dfn[a], dfn[a], b);
    				break;
    			case 2: 
    				// printf("%d %d
    ", a, b);
    				a = read(), b = read();
    				change(1, 1, n, dfn[a], dfn[a] + size[a] - 1, b);
    				break;
    			case 3: 
    				a = read();
    				// printf("%d
    ", a);
    				ask(1, a);
    				printf("%lld
    ", ans);
    				// write(ans);
    				puts("");
    				break;
    		}
    	}
    	
    	return 0;
    }
    
  • 相关阅读:
    PostgreSQL 安装和使用
    动态sql
    知识储备
    java空和非空判断
    我的第一篇博客
    正式工作:PreparedStatement 参与的
    mysql part2DML(数据操作语言)
    DCL(权限 ,用户)
    DQL(数据查询语言)
    准备工作:Eclipse 导入 mysql连接java 的jar包
  • 原文地址:https://www.cnblogs.com/liuzz-20180701/p/11524477.html
Copyright © 2011-2022 走看看