[题目链接]
[算法]
树上差分（配合 Tarjan 离线求 LCA），时间复杂度：近似 O(N + M)（并查集仅使用路径压缩）
[代码]
// Per-node path counting on a tree, using Tarjan's offline LCA plus
// "tree difference" (树上差分) tags swept in DFS order.
//
// Input (stdin): n m, then n-1 undirected edges, then w[1..n], then m pairs
// (s_i, t_i). Output (stdout): for every node j, the number of paths
// s_i -> t_i that pass through j with dist(s_i, j) == w[j], space-separated.
#include <cctype>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

// Reads a decimal integer from stdin. Any run of non-digit characters before
// the number is skipped, and each '-' in that run flips the sign. Stops at EOF
// instead of looping forever (the character is kept as int so EOF is
// distinguishable and std::isdigit only sees valid values).
template <typename T>
void read_int(T &x) {
    int sign = 1;
    int c = std::getchar();
    x = 0;
    for (; c != EOF && !std::isdigit(c); c = std::getchar()) {
        if (c == '-') sign = -sign;
    }
    for (; c != EOF && std::isdigit(c); c = std::getchar()) {
        x = x * 10 + (c - '0');
    }
    x *= sign;
}

// Disjoint-set "find" with full path compression. Written iteratively so a
// long parent chain cannot overflow the call stack.
int find_root(std::vector<int> &parent, int x) {
    int root = x;
    while (parent[root] != root) root = parent[root];
    while (parent[x] != root) {
        const int next = parent[x];
        parent[x] = root;
        x = next;
    }
    return root;
}

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    read_int(n);
    read_int(m);
    if (n < 1 || m < 0) return 0;

    // Adjacency lists as singly linked edge chains; edge index 0 ends a list.
    std::vector<int> head(n + 1, 0), nxt(2 * n, 0), to(2 * n, 0);
    int edge_count = 0;
    auto add_edge = [&](int u, int v) {
        ++edge_count;
        to[edge_count] = v;
        nxt[edge_count] = head[u];
        head[u] = edge_count;
    };
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        read_int(u);
        read_int(v);
        add_edge(u, v);
        add_edge(v, u);
    }

    std::vector<int> w(n + 1, 0);
    for (int i = 1; i <= n; ++i) read_int(w[i]);

    // Each query is stored at both endpoints so that whichever endpoint the
    // DFS enters second can resolve the LCA.
    std::vector<int> s(m + 1, 0), t(m + 1, 0), lca(m + 1, 0);
    std::vector<std::vector<std::pair<int, int>>> queries(n + 1);
    for (int i = 1; i <= m; ++i) {
        read_int(s[i]);
        read_int(t[i]);
        queries[s[i]].emplace_back(t[i], i);
        queries[t[i]].emplace_back(s[i], i);
    }

    // Iterative DFS from node 1 (a recursive one can overflow the stack on a
    // path-shaped tree). Records DFS order, subtree sizes, depths, parents
    // and, Tarjan-style, the LCA of every query.
    std::vector<int> dsu(n + 1), fa(n + 1, 0), dep(n + 1, 0);
    std::vector<int> dfn(n + 1, 0), seq(n + 1, 0), subtree(n + 1, 0);
    std::vector<char> visited(n + 1, 0);
    for (int i = 0; i <= n; ++i) dsu[i] = i;
    int timer = 0;

    auto enter = [&](int u) {
        dfn[u] = ++timer;
        seq[timer] = u;
        subtree[u] = 1;
        visited[u] = 1;
        for (const auto &[other, id] : queries[u]) {
            // A finished node's dsu root is its deepest still-open ancestor,
            // i.e. its LCA with u; an open ancestor is its own root.
            if (visited[other]) lca[id] = find_root(dsu, other);
        }
    };

    struct Frame {
        int node;
        int parent;
        int edge;  // next edge of `node` still to explore (0 = done)
    };
    std::vector<Frame> dfs_stack;
    dfs_stack.reserve(n);
    enter(1);
    dfs_stack.push_back({1, 0, head[1]});
    while (!dfs_stack.empty()) {
        Frame &top = dfs_stack.back();
        if (top.edge == 0) {
            const int u = top.node;
            const int p = top.parent;
            dfs_stack.pop_back();
            if (p != 0) {
                dsu[u] = p;  // u is finished: merge it into its parent's set
                subtree[p] += subtree[u];
            }
            continue;
        }
        const int u = top.node;
        const int v = to[top.edge];
        top.edge = nxt[top.edge];
        // Skips the parent (and, on malformed input, any revisit).
        if (visited[v]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        enter(v);
        dfs_stack.push_back({v, u, head[v]});
    }

    // Path s -> t splits at l = lca into an upward part s..l and a downward
    // part below l down to t (l itself is counted once, on the upward part).
    //  - Upward node j: dist(s, j) = dep[s] - dep[j], so it matches iff
    //    dep[j] + w[j] == dep[s].
    //  - Downward node j: dist(s, j) = dep[s] + dep[j] - 2*dep[l], so it
    //    matches iff dep[j] - w[j] == 2*dep[l] - dep[s].
    // Each part is a +1 tag at its lower end and a -1 tag just above its upper
    // end, so summing a key's tags over subtree(j) counts the paths whose part
    // contains j. This one rule also covers l == s and l == t: for l == t the
    // two downward tags sit on t and cancel. fa == 0 means "above the root",
    // where no cancelling tag is needed.
    std::vector<std::vector<std::pair<int, int>>> up_tags(n + 1), down_tags(n + 1);
    for (int i = 1; i <= m; ++i) {
        const int l = lca[i];
        const int up_key = dep[s[i]];
        const int down_key = 2 * dep[l] - dep[s[i]];
        up_tags[s[i]].emplace_back(up_key, 1);
        if (fa[l] != 0) up_tags[fa[l]].emplace_back(up_key, -1);
        down_tags[t[i]].emplace_back(down_key, 1);
        down_tags[l].emplace_back(down_key, -1);
    }

    // Tag keys lie in [-(n-1), n-1] and are stored at key + n. Query keys
    // dep[j] +/- w[j] can fall outside that range (an unchecked lookup there
    // reads past the counters). Such keys can never match, so they map to
    // "no slot" (-1).
    std::vector<int> up_cnt(2 * n + 1, 0), down_cnt(2 * n + 1, 0);
    auto slot = [n](long long key) -> int {
        return (key < -n || key > n) ? -1 : static_cast<int>(key + n);
    };
    auto count_at = [](const std::vector<int> &cnt, int idx) {
        return idx < 0 ? 0 : cnt[idx];
    };
    std::vector<int> up_slot(n + 1, -1), down_slot(n + 1, -1);
    for (int j = 1; j <= n; ++j) {
        up_slot[j] = slot(static_cast<long long>(dep[j]) + w[j]);
        down_slot[j] = slot(static_cast<long long>(dep[j]) - w[j]);
    }

    // closes_at[p]: nodes whose subtree occupies DFS positions ending at p.
    std::vector<std::vector<int>> closes_at(n + 1);
    for (int j = 1; j <= n; ++j) {
        if (dfn[j] != 0) closes_at[dfn[j] + subtree[j] - 1].push_back(j);
    }

    // Sweep in DFS order. ans[j] = (counter after j's subtree) minus
    // (counter before j's subtree), for j's two lookup keys.
    std::vector<int> ans(n + 1, 0);
    for (int pos = 1; pos <= timer; ++pos) {
        const int u = seq[pos];
        ans[u] -= count_at(up_cnt, up_slot[u]) + count_at(down_cnt, down_slot[u]);
        for (const auto &[key, delta] : up_tags[u]) up_cnt[key + n] += delta;
        for (const auto &[key, delta] : down_tags[u]) down_cnt[key + n] += delta;
        for (const int j : closes_at[pos]) {
            ans[j] += count_at(up_cnt, up_slot[j]) + count_at(down_cnt, down_slot[j]);
        }
    }

    for (int j = 1; j <= n; ++j) std::printf(j < n ? "%d " : "%d\n", ans[j]);
    return 0;
}