一道挺裸的点分治模板题。
刚开始想用 map 水过去,并且对每个询问各做一次点分治(q 个询问就做 q 次),结果一直 TLE,T 到自闭。
最后发现:在每个重心处,把所有点到重心的距离 sort 一遍并去重,同时记录每个距离出现的次数,这样只需一次点分治就能对所有询问用双指针统计;距离相同的点对按出现次数单独计算,就不会因为重复值而漏掉配对了。
// Centroid decomposition: for each of q queries k, decide whether the weighted
// tree contains a path whose total length is exactly k.
// All queries are answered during a single decomposition. At every centroid,
// the distances to it are sorted and deduplicated once (keeping multiplicities),
// and then each query is matched with a two-pointer sweep.
#include <bits/stdc++.h>
using namespace std;

#define N 100010
int n, q, Q[N];
// ans[i]: number of vertex pairs whose path length equals Q[i].
// Only "zero / nonzero" matters for the output. The count can reach ~n^2/2,
// though, so it must be 64-bit to avoid signed overflow.
long long ans[N];

// Adjacency list stored as linked arc lists. Arc ids start at 1, so
// head[u] == 0 / nx == 0 means "no more arcs".
struct Graph
{
    struct node
    {
        int to, nx, w;
        node() {}
        node(int to, int nx, int w) : to(to), nx(nx), w(w) {}
    } a[N << 1];
    int head[N], pos;
    // Adds the undirected edge u-v with weight w as two directed arcs.
    void add(int u, int v, int w)
    {
        a[++pos] = node(v, head[u], w); head[u] = pos;
        a[++pos] = node(u, head[v], w); head[v] = pos;
    }
} G;
// Iterates over every arc leaving u, binding the neighbour to v and the weight to w.
#define erp(u) for (int it = G.head[u], v = G.a[it].to, w = G.a[it].w; it; it = G.a[it].nx, v = G.a[it].to, w = G.a[it].w)

bool vis[N];                  // vis[u]: u has already been used as a centroid
int root, sum, sze[N], f[N];  // sum: size of the current component; f[u]: largest piece left if u is removed

// Fills sze[] for the subtree of u (entered from fa) inside the current
// component and returns sze[u].
int getsize(int u, int fa)
{
    sze[u] = 1;
    erp(u) if (v != fa && !vis[v]) sze[u] += getsize(v, u);
    return sze[u];
}

// Finds the centroid of the component containing u (of size `sum`).
// It stores the result in `root`; the caller must set root = 0 and f[0] = sum first.
void getroot(int u, int fa)
{
    sze[u] = 1, f[u] = 0;
    erp(u) if (v != fa && !vis[v])
    {
        getroot(v, u);
        sze[u] += sze[v];
        f[u] = max(f[u], sze[v]);
    }
    f[u] = max(f[u], sum - sze[u]);
    if (f[u] < f[root]) root = u;
}

// deep[1..deep[0]]: collected distances; d[x]: distance of x from the current centroid.
int deep[N], d[N];
// Appends d[] of every vertex reachable from u (without crossing used centroids) to deep[].
void getdeep(int u, int fa)
{
    deep[++deep[0]] = d[u];
    erp(u) if (v != fa && !vis[v])
    {
        d[v] = d[u] + w;
        getdeep(v, u);
    }
}

// g[1..g[0]]: distinct distances in increasing order; num[j]: multiplicity of g[j].
// Both are bounded by the component size, so N entries suffice.
int g[N], num[N];
// For every query i, adds opt * (number of unordered vertex pairs {x, y} with
// d[x] + d[y] == Q[i]) to ans[i]. The DFS starts at u with d[u] = cost.
//  - opt = +1, cost = 0 at the centroid: counts all pairs in the component.
//  - opt = -1, cost = edge weight at a child: removes pairs lying inside the same
//    child subtree, whose real path does not pass through the centroid.
void calc(int u, int cost, int opt)
{
    d[u] = cost; deep[0] = 0;
    getdeep(u, 0);
    sort(deep + 1, deep + 1 + deep[0]);
    g[0] = 0;
    for (int i = 1; i <= deep[0]; ++i)
    {
        if (i == 1 || deep[i] != deep[i - 1]) g[++g[0]] = deep[i], num[g[0]] = 1;
        else ++num[g[0]];
    }
    for (int i = 1; i <= q; ++i)
    {
        // Both vertices at the same distance: C(num, 2) unordered pairs.
        for (int j = 1; j <= g[0]; ++j)
            if (num[j] > 1 && 2LL * g[j] == Q[i])
                ans[i] += opt * (1LL * num[j] * (num[j] - 1) / 2);
        // Two different distances. g is strictly increasing, so for each l at most
        // one r matches, and r only ever needs to move left as l grows.
        int r = g[0];
        for (int l = 1; l < r; ++l)
        {
            while (l < r && (long long)g[l] + g[r] >= Q[i])
            {
                if ((long long)g[l] + g[r] == Q[i]) ans[i] += opt * (1LL * num[l] * num[r]);
                --r;
            }
        }
    }
}

// Processes the component whose centroid is u, then recurses into each
// component left after removing u.
void solve(int u)
{
    calc(u, 0, 1);
    vis[u] = 1;
    erp(u) if (!vis[v])
    {
        calc(v, w, -1);
        // Recompute the child component's size. The sze[v] left over from the
        // previous getroot pass was rooted at a different vertex and is wrong
        // when v was u's DFS parent in that pass.
        sum = f[0] = getsize(v, u);
        root = 0;
        getroot(v, u);
        solve(root);
    }
}

// Reads the tree (n vertices, n-1 weighted edges) and q query lengths, then
// prints "Yes"/"No" per query.
void Run()
{
    scanf("%d%d", &n, &q);
    for (int i = 1, u, v, w; i < n; ++i)
    {
        scanf("%d%d%d", &u, &v, &w);
        G.add(u, v, w);
    }
    for (int qq = 1; qq <= q; ++qq) scanf("%d", Q + qq);
    sum = n; f[0] = n;
    root = 0;
    getroot(1, 0);
    solve(root);
    // A query of length 0 is always answered "Yes" (a single vertex).
    for (int i = 1; i <= q; ++i) puts(ans[i] || !Q[i] ? "Yes" : "No");
}

int main()
{
    Run();
}