Description
Solution
设点(i)到根的第一条边的颜色为(col_i),根到点(i)的路径上的颜色和是(sum_i),经过观察发现(col_i)相同的不在同一个子树里的两个点之间的简单路径拼接后的答案是(sum_i + sum_j - w_{col_i})。这是因为中间的一段会被重复算两次,而如果(col_i)和(col_j)不同,答案就是(sum_i + sum_j)。
这样启示我们对于(col_i)相同的点的子树和不同的点的子树分别算答案,点(i)和点(j)之间的简单路径长度就是(dep_i + dep_j)(假设根的(dep)是(0))。
首先根据(col_i)排序,将(col_i)相同的放在一起处理,对于相同的(col_i)就是找(dep)满足条件且(sum)最大的点。处理一个点的子树时,可以对子树中的每个点按照深度从小到大进行排序,这样对应的满足条件的点的深度就是递减的。发现这时对应的点的选择是满足单调性的,就可以用单调队列维护。具体操作是对于之前(col_i)相同但是不在同一个子树内的点开一个桶记录深度为(x)时的最大(sum)值,每次处理一个新的子树时先把单调队列清空,然后按深度从大到小依次加入单调队列,这样就可以处理(col_i)相同的子树的答案,处理完之后将这个子树的答案加入桶里。
对于不同的子树,有类似的求法。如果直接按照(dep)排序有点慢,直接用(bfs)就可以优化掉这部分复杂度。
每次按照深度将点加入单调队列的复杂度是之前的最大深度,如果第一个处理的就是深度最大的子树,复杂度就会爆炸,所以要按照子树中最大深度进行排序。对于不同(col_i)的子树集合,也要按照其中出现的最大深度排序。
由于根节点的深度为(0),所以如果在(col_i)相同的子树中处理到根节点的路径,会多减去(w_{col_i}),于是在对于(col_i)不同的时候算答案的时候算上即可。
接下来证明一下按照深度排序之后每次加入单调队列的复杂度。
对于(col_i)相同的块内我们每次加入单调队列时的复杂度是之前所有出现过的节点的最深深度,由于我们按照最深节点的深度排序,所以每次处理的子树大小一定是递增的。那么如果深度增加了(x),则子树大小至少增加了(x),如果新的子树不是一条链,那么增加的子树大小还会更多,显然是一条链的时候增加的节点数是最少的,那么如果每次处理的子树都是一条链,增加的深度就是节点个数。也就是说,每次只会产生(O()节点个数())的复杂度,而最后一次最深的就不需要加入队列了;查询的时候同理,最小的深度就不需要加入队列了。这样的话复杂度最坏就是(2 imes sum dep - mindep + maxdep),因为(sum dep)是小于等于(n)的,所以复杂度是(O(n))级别的。
这样每一个(col_i)相同的块会产生最多(2 imes siz)的复杂度,(col_i)不同的最多会产生(2 imes allsiz)的复杂度,相加最多是(4 imes allsiz)的复杂度,还是(O(n))级别的。
对于排序,我们发现每条边只会被排序一次,设每次重心相连的变数为(s_i)那么总复杂度是(sum log_{s_i}),这是小于点分治的复杂度(log_n)的。
所以总复杂度是点分治的(n log_n)。
Code
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 200000;
const int INF = 214748364;
int dep[N +50], col[N + 50], q[N + 50], ans1[N + 50], ans2[N + 50], coldep[N +50], siz[N + 50], had[N + 50], w[N + 50], vis[N + 50], sum[N + 50], tmp[N + 50], inq[N + 50], ans[N + 50], px[N + 50], pxnum;
int num, n, m, maxx, root, all, maxx1, maxx2, L, R, answer = -INF, tmpans;
struct Node
{
int next, to, dis;
} edge[N * 2 + 50];
struct Tmp
{
int id, val;
};
void Addedge(int u, int v, int w)
{
edge[++num] = (Node){had[u], v, w};
had[u] = num;
return;
}
struct Queue
{
int head, tail;
Tmp q[N + 50];
void Clear() { head = 1; tail = 0;}
int Empty() { return head > tail; }
Tmp Top() { return q[head]; }
void Pop() { head++;}
void Push(Tmp x)
{
while (head <= tail && x.val >= q[tail].val) tail--;
q[++tail] = x;
}
} q1;
void Findroot(int x, int fa)
{
int maxson = 0;
siz[x] = 1;
for (int i = had[x]; i; i = edge[i].next)
{
int v = edge[i].to;
if (v == fa || vis[v]) continue;
Findroot(v, x);
siz[x] += siz[v];
maxson = max(maxson, siz[v]);
}
maxson = max(maxson, all - siz[x]);
if (maxson < maxx) root = x, maxx = maxson;
return;
}
void Getdep(int x, int len, int fa)
{
dep[x] = len;
for (int i = had[x]; i; i = edge[i].next)
{
int v = edge[i].to;
if (v == fa || vis[v]) continue;
Getdep(v, len + 1, x);
dep[x] = max(dep[x], dep[v]);
}
return;
}
int Cmp(int a, int b)
{
return col[a] == col[b] ? dep[a] < dep[b] : coldep[col[a]] < coldep[col[b]];
}
void Bfs2(int x)
{
// cout << "同块内" << endl;
q1.Clear();
int now = maxx2, head = 0, tail = 1;
q[tail] = x; inq[x] = 1; dep[x] = 1;
while (head < tail)
{
int u = q[++head];
while (now > 0 && now + dep[u] >= L) q1.Push((Tmp){now, ans2[now]}), now--;
while (!q1.Empty() && dep[u] + q1.Top().id > R) q1.Pop();
// if (!q1.Empty()) cout << q1.Top().id << " " << u << " " << dep[u] << endl;
if (!q1.Empty()) ans[u] = sum[u] + q1.Top().val - w[col[x]]; else ans[u] = -INF;
answer = max(answer, ans[u]);
for (int i = had[u]; i; i = edge[i].next)
{
int v = edge[i].to;
if (inq[v] || vis[v]) continue;
dep[v] = dep[u] + 1; inq[v] = 1; q[++tail] = v;
col[v] = edge[i].dis;
sum[v] = sum[u] + (col[v] == col[u] ? 0 : w[edge[i].dis]);
}
}
// for (int i = 1; i <= tail; i++) cout << q[i] << " " << sum[q[i]] << " " << dep[q[i]] << endl;
for (int i = now + 1; i <= min(dep[q[tail]], R); i++) ans2[i] = -INF;
maxx2 = max(maxx2, min(dep[q[tail]], R));
for (int i = 1; i <= tail && dep[q[i]] <= R; i++) ans2[dep[q[i]]] = max(ans2[dep[q[i]]], sum[q[i]]);
// cout << maxx2 << endl;
// for (int i = 1; i <= maxx2; i++) cout << ans2[i] << " ";
// cout << endl;
for (int i = 1; i <= tail; i++) inq[q[i]] = 0;
return;
}
void Bfs1(int tot)
{
if (!tot) return;
// cout << tot << endl;
//// cout << "不同块" << endl;
q1.Clear();
int now = maxx1;
int head = 0, tail = 0;
for (int i = 1; i <= tot; i++) q[++tail] = tmp[i], dep[tmp[i]] = 1, inq[tmp[i]] = 1;
while (head < tail)
{
int u = q[++head];
// cout << u << " " << dep[u] << endl;
while (now >= 0 && now + dep[u] >= L) q1.Push((Tmp){now, ans1[now]}), now--;
while (!q1.Empty() && q1.Top().id + dep[u] > R) q1.Pop();
// if (!q1.Empty())cout << q1.Top().val << " " << u << " " << dep[u] << endl;
if (!q1.Empty()) ans[u] = sum[u] + q1.Top().val; else ans[u] = -INF;
answer = max(answer, ans[u]);
for (int i = had[u]; i; i = edge[i].next)
{
int v = edge[i].to;
if (vis[v] || inq[v]) continue;
inq[v] = 1; dep[v] = dep[u] + 1; q[++tail] = v;
col[v] = edge[i].dis;
sum[v] = sum[u] + (col[u] == col[v] ? 0 : w[edge[i].dis]);
}
}
// for (int i = 1; i <= tail; i++) cout << q[i] << " " << sum[q[i]] << " " << dep[q[i]] << endl;
for (int i = now + 1; i <= min(R, dep[q[tail]]); i++) ans1[i] = -INF;
maxx1 = max(maxx1, min(R, dep[q[tail]]));
for (int i = 1; i <= tail && dep[q[i]] <= R; i++) ans1[dep[q[i]]] = max(ans1[dep[q[i]]], sum[q[i]]);
for (int i = 1; i <= tail; i++) inq[q[i]] = 0; maxx2 = 0;
return;
}
void Solve(int x)
{
// cout << "重心:" << x << endl;
vis[x] = 1;
for (int i = had[x]; i; i = edge[i].next) coldep[edge[i].dis] = 0, col[edge[i].to] = edge[i].dis;
for (int i = had[x]; i; i = edge[i].next)
{
int v = edge[i].to;
if (vis[v]) continue;
px[++pxnum] = v;
Getdep(v, 1, x);
coldep[edge[i].dis] = max(coldep[edge[i].dis], dep[v]);
}
sort(px + 1, px + pxnum + 1, Cmp);
maxx1 = 0; maxx2 = 0;
int precol = 0, tot = 0; ans1[0] = 0;
for (int i = 1; i <= pxnum; i++)
{
int v = px[i];
if (vis[v]) continue;
if (col[v] != precol) precol = col[v], Bfs1(tot), tot = 0;
sum[v] = w[col[v]];
Bfs2(v); tmp[++tot] = v;
}
Bfs1(tot);
pxnum = 0;
for (int i = had[x]; i; i = edge[i].next)
{
int v = edge[i].to;
if (vis[v]) continue;
all = siz[v]; maxx = INF; Findroot(v, x); Solve(root);
}
return;
}
void Read(int &x)
{
x = 0; int p = 0; char st = getchar();
while (st < '0' || st > '9') p = (st == '-'), st = getchar();
while (st >= '0' && st <= '9') x = (x << 1) + (x << 3) + st - '0', st = getchar();
x = p ? -x : x;
return;
}
int main()
{
Read(n); Read(m); Read(L); Read(R);
for (int i = 1; i <= m; i++) Read(w[i]);
for (int i = 1, u, v, w; i <= n - 1; i++)
{
Read(u); Read(v); Read(w);
Addedge(u, v, w);
Addedge(v, u, w);
}
all = n; maxx = INF; Findroot(1, 0); Solve(root);
printf("%d", answer);
return 0;
}