测试地址:Free Tour II
题目大意:给定一棵树,边有边权,点分为黑点和白点,求经过不超过个黑点的路径的最大边权和。
做法:本题需要用到点分治+桶排序。
首先看到求这样的路径问题,如果DP看上去非常难做,就考虑用点分治。
对于点分治中每个块的根节点,我们考虑求经过这个根节点的合法路径的最大边权和。对于块中的每个点我们得到了一个数对,其中为从根到该点路径上经过的黑点数(为了方便,这里不包括根节点),为路径和。那么我们维护一个数组,其中为经过不超过个点的路径的最大边权和,并且路径一端在根节点,另一端在已经处理过的点中。然后对于正在处理的块中的每一个点,我们实际上就是要求,在的条件下,的最大值,这个我们就可以用两个指针求出了。
然而注意到,上述算法的复杂度最快也是的,再看看数据范围和这个OJ的名字……没办法,优化吧。
优化的方法是,我们把要处理的块按照大小从小到大排序,显然经过的黑点数量不超过,那么我们把各种排序用桶排序解决,那么每一块的时间复杂度就是的,又因为我们已经将块按照大小从小到大排序,所以这个步骤的总时间复杂度就是,这样我们就把算法的时间复杂度降为了,可以通过此题。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,k,m,tot=0,first[200010]={0};
int siz[200010],mxson[200010],q[200010],top;
int srt[200010]={0},nex[200010]={0},black[200010]={0};
ll ww[200010]={0},nowblock[200010]={0},past[200010]={0},ans=0;
bool vis[200010]={0};
struct edge
{
int v,next;
ll w;
}e[400010];
void insert(int a,int b,ll w)
{
e[++tot].v=b;
e[tot].w=w;
e[tot].next=first[a];
first[a]=tot;
}
void dp(int v,int f)
{
q[++top]=v;
siz[v]=1,mxson[v]=0;
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=f&&!vis[e[i].v])
{
dp(e[i].v,v);
mxson[v]=max(mxson[v],siz[e[i].v]);
siz[v]+=siz[e[i].v];
}
}
int find(int v)
{
int ans=1000000000,ansp=0;
top=0;
dp(v,0);
for(int i=1;i<=top;i++)
{
if (max(siz[v]-siz[q[i]],mxson[q[i]])<ans)
ans=max(siz[v]-siz[q[i]],mxson[q[i]]),ansp=q[i];
}
return ansp;
}
void getdown(int v,int f,int passed,ll x)
{
nowblock[passed]=max(nowblock[passed],x);
for(int i=first[v];i;i=e[i].next)
if (e[i].v!=f&&!vis[e[i].v])
getdown(e[i].v,v,passed+black[e[i].v],x+e[i].w);
}
int solve(int v)
{
v=find(v);
vis[v]=1;
int blocksiz=1,pastmxsiz=0;
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v])
{
siz[e[i].v]=solve(e[i].v);
ww[e[i].v]=e[i].w;
}
for(int i=first[v];i;i=e[i].next)
if (!vis[e[i].v])
{
nex[e[i].v]=srt[siz[e[i].v]];
srt[siz[e[i].v]]=e[i].v;
blocksiz+=siz[e[i].v];
}
for(int i=0;i<=blocksiz;i++)
while (srt[i])
{
getdown(srt[i],0,black[srt[i]],ww[srt[i]]);
for(int j=0;j<=i;j++)
{
if (j>0) nowblock[j]=max(nowblock[j-1],nowblock[j]);
if (k-j-black[v]>pastmxsiz) ans=max(ans,nowblock[j]+past[pastmxsiz]);
else if (k-j-black[v]>=0) ans=max(ans,nowblock[j]+past[k-j-black[v]]);
}
for(int j=0;j<=i;j++)
{
past[j]=max(past[j],nowblock[j]);
if (j>0) past[j]=max(past[j],past[j-1]);
nowblock[j]=0;
}
pastmxsiz=i;
srt[i]=nex[srt[i]];
}
for(int i=0;i<=blocksiz;i++)
srt[i]=past[i]=0;
vis[v]=0;
return blocksiz;
}
int main()
{
scanf("%d%d%d",&n,&k,&m);
for(int i=1;i<=m;i++)
{
int x;
scanf("%d",&x);
black[x]=1;
}
for(int i=1;i<n;i++)
{
int a,b;ll w;
scanf("%d%d%lld",&a,&b,&w);
insert(a,b,w),insert(b,a,w);
}
solve(1);
printf("%lld",ans);
return 0;
}