题目
题目链接:https://atcoder.jp/contests/agc007/tasks/agc007_e
一颗(n)个节点的二叉树,每个节点要么有两个儿子要么没有儿子。边有边权。
你从(1)号节点出发,走到一个叶子节点。然后每一天,你可以从当前点走到另一个叶子。最后回到(1)号节点,要求到过所有叶子并且每条边经过恰好两次。
每天的路费是你走过的路径上的边权和,你的公司会为你报销大部分路费,除了你旅行中所用路费最高的,行走路线是从叶子到叶子的那一天的路费。
求你自己最少要付多少路费?
(nleq 131072)。
思路
首先答案显然满足单调性,直接二分后转化为判定。也就是是否存在一个 dfs 序满足路径相邻叶子之间的距离不超过 ( ext{mid})。
设 (f[x][a][b]) 表示点 (x) 为根的子树内,(x) 到第一个叶子的距离为 (a),(x) 到最后一个叶子的距离为 (b),且相邻两个叶子之间距离不超过 ( ext{mid}) 是否可行。
转移显然
看着这复杂度就爆炸。但是我们发现有很多状态是无用的,当 (f[x][a][b]=f[x][c][d]=1(aleq c,bleq d)) 时,(f[x][c][d]) 就是完全不会造成任何贡献的。
所以说我们就不用把所有状态记录下来,直接对每一个点 (x) 维护一个集合 (F_x),里面存上所有合法的二元组 ((a,b)),并且所有二元组都是可能会有贡献的。也就是当所有二元组按照 (a) 升序排序时,(b) 一定是严格降序的。
假设我们确定了经过 (x) 后先走左子树再走右子树,假定两个子树内的元素都是根据 (a) 升序排列,显然对于左子树的一个合法二元组 ((a,b)),右子树中能和它合并的一定是集合的一个前缀。并且我们一定取前缀的最后一位最优。因为合并起来的路径长度只要不超过 ( ext{mid}) 都没有区别,(x) 到左子树第一个叶子路径长度一定是固定的 (a+ ext{dis}_{x,lc}),只需要 (x) 到右子树最后一个叶子的距离最小即可。由于 (b) 是单调减的,所以只需要找到合法前缀的最后一个元素即可。
那么直接双指针扫一下就可以找到每一个二元组在另一个子树中对应的二元组了。但是我们不能直接扔进 (F_x),因为不一定是先走左子树。
所以我们就分开处理先走左右子树的情况,这两种情况合并出来的路径分别是单调的,我们就扔进两个 vector 中,然后进行一次归并就可以保证 (F_x) 也是单调的了。
处理点 (x) 的复杂度是有关其左右子树的合法状态数的。我们发现剔除不合法状态后,(x) 的状态数上界是其两个子树状态数较小值的两倍,所以总状态数是 (O(nlog n)) 的。时间复杂度也就是 (O(nlog nlog (na))) 了。
代码
#include <bits/stdc++.h>
#define mp make_pair
#define ST first
#define ND second
using namespace std;
typedef long long ll;
const int N=140010;
int n,ch[N][2],dis[N][2];
ll mid;
vector<pair<ll,ll> > f[N],g[2];
void ins(int x,int y,int j)
{
int i=f[x].size()-1;
if (i>=0)
{
if (f[x][i].ND<=g[y][j].ND) return;
if (f[x][i].ST==g[y][j].ST) f[x].pop_back();
}
f[x].push_back(g[y][j]);
}
void dfs(int x)
{
if (!ch[x][0]) return (void)(f[x].push_back(mp(0,0)));
int lc=ch[x][0],rc=ch[x][1],d1=dis[x][0],d2=dis[x][1];
dfs(lc); dfs(rc);
for (int i=0,j=0;i<f[lc].size();i++)
{
while (j<f[rc].size() && f[lc][i].ND+f[rc][j].ST+d1+d2<=mid) j++;
if (!j) continue;
g[0].push_back(mp(f[lc][i].ST+d1,f[rc][j-1].ND+d2));
}
for (int i=0,j=0;i<f[rc].size();i++)
{
while (j<f[lc].size() && f[rc][i].ND+f[lc][j].ST+d1+d2<=mid) j++;
if (!j) continue;
g[1].push_back(mp(f[rc][i].ST+d2,f[lc][j-1].ND+d1));
}
for (int i=0,j=0;i<g[0].size() || j<g[1].size();)
{
if (i==g[0].size()) { ins(x,1,j); j++; continue; }
if (j==g[1].size()) { ins(x,0,i); i++; continue; }
if (g[0][i].ST<=g[1][j].ST) ins(x,0,i),i++;
else ins(x,1,j),j++;
}
g[0].clear(); g[1].clear();
f[lc].clear(); f[rc].clear();
}
int main()
{
scanf("%d",&n);
for (int i=2,x,y;i<=n;i++)
{
scanf("%d%d",&x,&y);
if (ch[x][0]) ch[x][1]=i,dis[x][1]=y;
else ch[x][0]=i,dis[x][0]=y;
}
ll l=0,r=1000000000000LL;
while (l<=r)
{
for (int i=1;i<=n;i++) f[i].clear();
mid=(l+r)/2LL;
dfs(1);
if (f[1].size()) r=mid-1;
else l=mid+1;
}
printf("%lld",r+1);
return 0;
}