一道树的直径
树网的核 BZOJ原题链接
树网的核 洛谷原题链接
消防 BZOJ原题链接
消防 洛谷原题链接
一份代码四倍经验,爽
显然要先随便找一条直径,然后直接枚举核的两个端点,对每一次枚举的核遍历核上的每个点,用(dfs)求出核外节点到核的最大值即可,时间复杂度为(O(n^3)),这在(NOIP)的原数据范围下是可以过的,但对于数据加强版就必须要优化了。
发现当枚举到直径上的某个点时,核的另一端在不超过(s)的前提下显然越远越好。这样就直接优化掉一个(n)了,但我们还可以继续优化。
设直径上的点为(u_1,u_2,dots,u_t),当前枚举到的核的两端点为(x_i,x_j)。
根据直径的最长性,我们可以发现对于该核的偏心距实际上就是(max{maxlimits_{k=1}^{t}{d[u_k]},dis(u_1,x_i),dis(x_j,u_t)}),数组(d)表示直径外节点(不经过直径上的点)到(u_k)的最大值,(dis)表示两点间的距离。
而(maxlimits_{k=1}^{t}{d[u_k]})显然是个定值,至于(dis),我们可专门剖出直径上的所有边,然后用在枚举核的左端点时用两个变量维护即可,时间复杂度(O(n))。
#include<cstdio>
using namespace std;
const int N = 5e5 + 10;
struct dd {
int dis, x;
};
dd D[N], a[N];
int fi[N], di[N << 1], da[N << 1], ne[N << 1], l, ma;
bool v[N];
inline int re()
{
int x = 0;
char c = getchar();
bool p = 0;
for (; c<'0' || c>'9'; c = getchar())
p |= c == '-';
for (; c >= '0'&&c <= '9'; c = getchar())
x = x * 10 + (c - '0');
return p ? -x : x;
}
inline int maxn(int x, int y)
{
return x > y ? x : y;
}
inline int minn(int x, int y)
{
return x < y ? x : y;
}
inline void add(int x, int y, int z)
{
di[++l] = y;
da[l] = z;
ne[l] = fi[x];
fi[x] = l;
}
void dfs(int x, int fa, int dis, int la)
{
int i, y;
if (dis > ma)
{
ma = dis;
D[0].x = x;
}
D[x].x = fa;
D[x].dis = la;
for (i = fi[x]; i; i = ne[i])
{
y = di[i];
if (y != fa)
dfs(y, x, dis + da[i], da[i]);
}
}
void dfs_2(int x, int dis)
{
int i, y;
v[x] = 1;
if (dis > ma)
ma = dis;
for (i = fi[x]; i; i = ne[i])
{
y = di[i];
if (!v[y])
dfs_2(y, dis + da[i]);
}
}
int main()
{
int i, j, n, m, x, y, z, s = 0, k = 0, tail = 0, head = 0, an = 1e9;
n = re();
m = re();
for (i = 1; i < n; i++)
{
x = re();
y = re();
z = re();
add(x, y, z);
add(y, x, z);
}
dfs(1, 0, 0, 0);
ma = 0;
dfs(D[0].x, 0, 0, 0);
for (i = D[0].x; i; i = D[i].x)
{
v[i] = 1;
a[++k].x = i;
a[k].dis = D[i].dis;
}
ma = 0;
for (i = 1; i <= k; i++)
dfs_2(a[i].x, 0);
for (j = 1; j < n; j++)
if (s + a[j].dis <= m)
s += a[j].dis;
else
break;
for (i = j; i < n; i++)
tail += a[i].dis;
an = minn(an, maxn(ma, maxn(head, tail)));
for (i = 1; i < n; i++)
{
s -= a[i].dis;
head += a[i].dis;
for (; j < n; j++)
if (s + a[j].dis <= m)
{
s += a[j].dis;
tail -= a[j].dis;
}
else
break;
an = minn(an, maxn(ma, maxn(head, tail)));
}
printf("%d", an);
return 0;
}