显然可以用线段树合并来做(我也不知道哪来这么多显然)。
考虑将每条路径拆成两条路径 s -> lca 和 t -> lca 。
对于前一种路径上的某一点i,希望在时刻 w[i] 经过它,那么就有
dep[s] - dep[i] = w[i]
移项可得:
dep[s] = w[i] + dep[i]
然后发现 dep[s] 可以被看做已知条件,那么根据常用套路,在点 s 将线段树 dep[s] 处的值 + 1,在 lca 处算完答案之后再还原(这样 lca 本身算在这一段路径里)。
回溯的过程中把子树的线段树合并上来,再用下标 w[i] + dep[i] 查询答案就好了。
对于后一种路径,从 s 出发到达点 i 的时刻是 (dep[s] - dep[lca]) + (dep[i] - dep[lca])。考虑在点 t 的时候将线段树的某个位置 + 1,在 lca 计算答案之前还原(lca 已经在前一种路径里统计过,这样可以避免重复统计)。那么就有:
dep[s] + dep[i] - 2 * dep[lca] = w[i]
然后移项得到:
dep[s] - 2 * dep[lca] = w[i] - dep[i]
继续扔进线段树里维护和查询。注意这里的下标可能是负数,所以线段树的值域要开成 [-n, n]。
(前年做这道题的时候完全没有思路,去年发现原来这么水。。)
Code
/**
 * luogu
 * Problem#1600
 * Accepted
 * Time: 6321ms
 * Memory: 172394k
 *
 * Segment-tree-merge solution. Every path s -> t is split at its LCA:
 *  - s -> lca (lca included): node i sees the runner iff
 *        dep[s] == w[i] + dep[i];
 *  - t -> lca (lca excluded): node i sees the runner iff
 *        dep[s] - 2 * dep[lca] == w[i] - dep[i].
 * Both counts are collected bottom-up by merging the children's trees.
 */
#include <bits/stdc++.h>
using namespace std;
#define smin(_a, _b) _a = min(_a, _b)
#define smax(_a, _b) _a = max(_a, _b)
#define fi first
#define sc second
typedef pair<int, int> pii;
typedef bool boolean;

// Reads a non-negative decimal integer from stdin into u, skipping any
// non-digit characters in front of it. If input ends before a digit is
// found, u is set to 0 instead of looping forever.
template<typename T>
inline void readInteger(T& u) {
    int x;  // int, not char, so EOF stays distinguishable from real bytes
    while ((x = getchar()) != EOF && !isdigit(x));
    if (x == EOF) {
        u = 0;
        return;
    }
    for (u = x - '0'; isdigit(x = getchar()); u = u * 10 + x - '0');
}

// Node of a dynamically created sum segment tree. Absent children point to
// the shared sentinel `null` (never nullptr), so pushUp() and merge() can
// read through them without extra checks.
class SegTreeNode {
public:
    int val;
    SegTreeNode *l, *r;

    SegTreeNode(int val = 0, SegTreeNode* l = NULL, SegTreeNode* r = NULL):val(val), l(l), r(r) { }

    // Recomputes this node's sum from its two children.
    inline void pushUp() {
        val = l->val + r->val;
    }
};

#define LIMIT 2000000
SegTreeNode pool[LIMIT + 1];
SegTreeNode *top = pool;
// Sentinel meaning "empty subtree": value 0, both children point to itself.
SegTreeNode nullNode(0, &nullNode, &nullNode);
SegTreeNode* const null = &nullNode;

// Returns a fresh empty node: from the static pool while it lasts, from the
// heap once the pool is exhausted.
SegTreeNode* newnode() {
    if (top >= pool + LIMIT)
        return new SegTreeNode(0, null, null);
    *top = SegTreeNode(0, null, null);
    return top++;
}

// Merges tree b into tree a by summing overlapping nodes. Nodes of b are
// reused by a, so b must not be used on its own afterwards.
void merge(SegTreeNode*& a, SegTreeNode* b) {
    if (a == null) {
        a = b;
        return;
    }
    if (b == null) return;
    a->val += b->val;
    merge(a->l, b->l);
    merge(a->r, b->r);
}

// Thin wrapper that owns the root of one sparse segment tree.
class SegTree {
public:
    SegTreeNode* root;

    SegTree():root(null) { }

    // Adds val at position idx of the tree rooted at node, which covers [l, r].
    void update(SegTreeNode*& node, int l, int r, int idx, int val) {
        if (node == null)
            node = newnode();
        if (l == idx && r == idx) {
            node->val += val;
            return;
        }
        int mid = (l + r) >> 1;  // floor division, also correct for negative l
        if (idx <= mid)
            update(node->l, l, mid, idx, val);
        else
            update(node->r, mid + 1, r, idx, val);
        node->pushUp();
    }

    // Returns the value stored at position idx (0 if it was never touched).
    int query(SegTreeNode* node, int l, int r, int idx) {
        if (node == null) return 0;
        if (l == idx && r == idx)
            return node->val;
        int mid = (l + r) >> 1;
        if (idx <= mid)
            return query(node->l, l, mid, idx);
        return query(node->r, mid + 1, r, idx);
    }
};

// One runner's path: start s, end t, and their LCA (0 until computed).
class Query {
public:
    int s;
    int t;
    int lca;

    Query(int s = 0, int t = 0, int lca = 0):s(s), t(t), lca(lca) { }
};

int n, m;
int* wss;            // wss[i]: observation time of node i
Query* qs;           // qs[1..m]: the paths
vector<int> *q;      // q[x]: ids of paths having x as an endpoint
vector<int> *g;      // adjacency lists of the tree

// Reads the tree, the observation times and the paths from stdin.
inline void init() {
    readInteger(n);
    readInteger(m);
    wss = new int[(n + 1)];
    qs = new Query[(m + 1)];
    q = new vector<int>[(n + 1)];
    g = new vector<int>[(n + 1)];
    for (int i = 1, u, v; i < n; i++) {
        readInteger(u);
        readInteger(v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int i = 1; i <= n; i++)
        readInteger(wss[i]);
    for (int i = 1, s, t; i <= m; i++) {
        readInteger(s);
        readInteger(t);
        qs[i] = Query(s, t, 0);
        q[s].push_back(i);
        // A path with s == t is registered only once, at its single node.
        if (s != t)
            q[t].push_back(i);
    }
}

int* f;    // DSU parent for Tarjan's LCA; 0 means "not entered yet"
int* dep;  // depth of each node, root has depth 1

// DSU find with path compression.
int find(int x) {
    return (f[x] == x) ? (x) : (f[x] = find(f[x]));
}

// Offline Tarjan LCA; also fills dep[]. After a child's subtree finishes,
// the child's DSU parent becomes the current node.
// NOTE(review): recursion depth equals the tree depth (up to n), so a long
// chain may need an enlarged stack.
void tarjan(int node, int fa) {
    f[node] = node;
    dep[node] = dep[fa] + 1;
    for (int e : g[node]) {
        if (e == fa) continue;
        tarjan(e, node);
        f[e] = node;
    }
    for (int id : q[node]) {
        // Resolve the LCA as soon as both endpoints have been entered.
        if (f[qs[id].s] && f[qs[id].t] && !qs[id].lca)
            qs[id].lca = find(f[(qs[id].s == node) ? (qs[id].t) : (qs[id].s)]);
    }
}

int *ans;
vector<int>* ls;  // ls[x]: ids of paths whose contribution is undone at x (their LCA)

// s -> lca halves. Returns the merged tree of node's subtree, indexed by
// dep[s] over [1, n], and stores the partial answer for node in ans[node].
SegTreeNode* dfs1(int node, int fa) {
    SegTree st;
    for (int id : q[node]) {
        Query &aq = qs[id];
        if (aq.s == node) {
            st.update(st.root, 1, n, dep[aq.s], 1);
            ls[aq.lca].push_back(id);
        }
    }
    for (int e : g[node]) {
        if (e == fa) continue;
        merge(st.root, dfs1(e, node));
    }
    // Indices beyond n can never match any dep[s].
    if (dep[node] + wss[node] <= n)
        ans[node] = st.query(st.root, 1, n, dep[node] + wss[node]);
    else
        ans[node] = 0;
    // Undo paths ending their first half here AFTER querying: the LCA is
    // counted on this half.
    for (int id : ls[node])
        st.update(st.root, 1, n, dep[qs[id].s], -1);
    ls[node].clear();
    return st.root;
}

// t -> lca halves (LCA excluded). Trees are indexed by
// dep[s] - 2 * dep[lca] over [-n, n]; adds the count to ans[node].
SegTreeNode* dfs2(int node, int fa) {
    SegTree st;
    for (int id : q[node]) {
        Query &aq = qs[id];
        if (aq.t == node) {
            st.update(st.root, -n, n, dep[aq.s] - 2 * dep[aq.lca], 1);
            ls[aq.lca].push_back(id);
        }
    }
    for (int e : g[node]) {
        if (e == fa) continue;
        merge(st.root, dfs2(e, node));
    }
    // Undo BEFORE querying so the LCA is not counted a second time.
    for (int id : ls[node])
        st.update(st.root, -n, n, dep[qs[id].s] - 2 * dep[qs[id].lca], -1);
    ans[node] += st.query(st.root, -n, n, wss[node] - dep[node]);
    return st.root;
}

// Computes all LCAs, runs both passes and prints the answers.
inline void solve() {
    f = new int[(n + 1)];
    dep = new int[(n + 1)];
    dep[0] = 0;
    memset(f, 0, sizeof(int) * (n + 1));
    tarjan(1, 0);
    ans = new int[(n + 1)];
    ls = new vector<int>[(n + 1)];
    dfs1(1, 0);
    top = pool;  // dfs1's trees are dead now; recycle the pool for dfs2
    for (int i = 1; i <= n; i++)
        assert(ls[i].empty());
    dfs2(1, 0);
    for (int i = 1; i <= n; i++)
        printf("%d ", ans[i]);
}

int main() {
    init();
    solve();
    return 0;
}
线段树常数太大,跑得太慢了,在 CCF 的老旧评测机上应该会 TLE。
我们发现,这个本质上是在 dfs 序的某些位置对某一下标修改一个值,然后询问一段 dfs 序区间(也就是一棵子树)内某一下标的和。
这个完全可以用前缀和相减,而没必要出动线段树。
比较简单的写法就是:访问到一个点时,先记录一下它要询问的下标在桶里的值,然后递归它的子树,回溯时用桶里的新值减去记录的值,这个差值就是子树(含该点自身)的贡献,加到答案上。
Code
/**
 * uoj
 * Problem#261
 * Accepted
 * Time: 1236ms
 * Memory: 59108k
 *
 * Same decomposition as the segment-tree solution, but each subtree query
 * is answered with a global bucket: read the bucket before visiting a
 * subtree, read it again afterwards, and the difference is exactly what the
 * subtree contributed.
 */
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <vector>
using namespace std;
typedef bool boolean;

const int N = 3e5 + 3;
// Keys are shifted by N. This assumes 1 <= dep <= n, 0 <= w <= n and
// n <= N - 3, so every key lies in [-n, 2n] and fits in [0, 3N).
const int BUCKET_SIZE = 3 * N;

#define pii pair<int, int>
#define fi first
#define sc second

// Assigns val to every element of [ps, ped).
template <typename T>
void pfill(T* ps, const T* ped, T val) {
    for ( ; ps != ped; *ps = val, ps++);
}

// Array-based multimap from small int keys to values of type T.
// h[p] is the index in vs of the newest entry for key p (-1 if none), and
// vs[i].second links to the previous entry with the same key.
template <typename T>
class MapManager {
public:
    vector<int> h;
    vector< pair<T, int> > vs;

    MapManager() { }
    MapManager(int n) : h(n + 1, -1) { }

    void insert(int p, T dt) {
        vs.push_back(make_pair(dt, h[p]));
        h[p] = (signed) vs.size() - 1;
    }

    pair<T, int>& operator [] (int p) {
        return vs[p];
    }
};

int n, m;
int *dep;
int *wts, *res;
int *lcas, *us, *vs;  // indexed by query id 1..m
boolean *exi;         // exi[id]: LCA of query id not computed yet
MapManager<int> g;
MapManager<pii> qlca; // fi: another node. sc: id
MapManager<int> as, rs;
MapManager<pii> ms; // sc: val

// Reads the tree, the observation times and the paths from stdin.
inline void init() {
    scanf("%d%d", &n, &m);
    wts = new int[(n + 1)];
    res = new int[(n + 1)];
    exi = new boolean[(m + 1)];
    pfill(res + 1, res + n + 1, 0);
    pfill(exi, exi + m + 1, true);
    g = MapManager<int>(n);
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        g.insert(u, v), g.insert(v, u);
    }
    for (int i = 1; i <= n; i++)
        scanf("%d", wts + i);
    // Per-query arrays: sized by m, not n (m may exceed n).
    us = new int[(m + 1)];
    vs = new int[(m + 1)];
    lcas = new int[(m + 1)];
    qlca = MapManager<pii>(n);
    for (int i = 1, u, v; i <= m; i++) {
        scanf("%d%d", &u, &v);
        qlca.insert(u, pii(v, i));
        qlca.insert(v, pii(u, i));
        us[i] = u, vs[i] = v;
    }
}

int *uf;
boolean *vis;
int dfs_clock;

// DSU find with path compression.
int find(int x) {
    return (uf[x] == x) ? (x) : (uf[x] = find(uf[x]));
}

// Offline Tarjan LCA; also records dep[p] = dp.
// NOTE(review): recursion depth equals the tree depth (up to n).
void tarjan(int p, int fa, int dp) {
    vis[p] = true, dep[p] = dp;
    for (int i = g.h[p], e; ~i; i = g[i].sc) {
        if ((e = g[i].fi) == fa)
            continue;
        tarjan(e, p, dp + 1);
        uf[find(e)] = p;
    }

    for (int i = qlca.h[p]; ~i; i = qlca[i].sc) {
        pii d = qlca[i].fi;
        if (vis[d.fi] && exi[d.sc]) {
            exi[d.sc] = false;
            lcas[d.sc] = find(d.fi);
        }
    }
}

int bucket[BUCKET_SIZE];

// Adds val to the counter of key p. Every key uses the same offset, so
// distinct keys never share a slot.
void put(int p, int val) {
    bucket[p + N] += val;
}

// Returns the counter of key p.
int get(int p) {
    return bucket[p + N];
}

// s -> lca halves (LCA included): node p sees runners with
// dep[s] == w[p] + dep[p] that start in its subtree and are still active.
void dfs1(int p, int fa) {
    int tmp = get(wts[p] + dep[p]);
    for (int i = as.h[p]; ~i; i = as[i].sc)
        put(as[i].fi, 1);
    for (int i = g.h[p], e; ~i; i = g[i].sc) {
        if ((e = g[i].fi) == fa)
            continue;
        dfs1(e, p);
    }
    res[p] += get(wts[p] + dep[p]) - tmp;
    // Removed after counting p, so the LCA belongs to this half.
    for (int i = rs.h[p]; ~i; i = rs[i].sc)
        put(rs[i].fi, -1);
}

// t -> lca halves (LCA excluded): key dep[s] - 2 * dep[lca] gets +1 at t
// and -1 at the LCA, so the LCA's own window sees a net 0.
void dfs2(int p, int fa) {
    int tmp = get(wts[p] - dep[p]);
    for (int i = ms.h[p]; ~i; i = ms[i].sc)
        put(ms[i].fi.fi, ms[i].fi.sc);
    for (int i = g.h[p], e; ~i; i = g[i].sc) {
        if ((e = g[i].fi) == fa)
            continue;
        dfs2(e, p);
    }
    res[p] += get(wts[p] - dep[p]) - tmp;
}

// Computes LCAs, builds the per-node events, runs both passes and prints.
inline void solve() {
    uf = new int[(n + 1)];
    dep = new int[(n + 1)];
    vis = new boolean[(n + 1)];
    for (int i = 1; i <= n; i++)
        uf[i] = i;
    pfill(vis + 1, vis + n + 1, false);
    tarjan(1, 0, 1);

    delete[] vis;
    delete[] exi;

    as = MapManager<int>(n);
    rs = MapManager<int>(n);
    ms = MapManager<pii>(n);

    for (int i = 1; i <= m; i++) {
        int u = us[i], v = vs[i], lca = lcas[i];
        as.insert(u, dep[u]);
        rs.insert(lca, dep[u]);
        ms.insert(v, pii(dep[u] - 2 * dep[lca], 1));
        ms.insert(lca, pii(dep[u] - 2 * dep[lca], -1));
    }

    dfs1(1, 0);
    dfs2(1, 0);
    for (int i = 1; i <= n; i++)
        printf("%d ", res[i]);
}

int main() {
    init();
    solve();
    return 0;
}