原题
给出一颗有n个点的树,其中有M个点是拥挤的,请选出一条最多包含k个拥挤的点的路径使得经过的权值和最大。
正常树分治,每次处理路径,更新答案。
计算每棵子树的deep(本题以经过拥挤节点个数作为deep),然后记录mx[i]为当前为止经过i个拥挤节点所达到的最大价值,tmp[i]为当前所在树中经过i个拥挤节点所达到的最大价值,用于更新答案即可。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#define N 200010
using namespace std;
int ans,n,K,m,cnt,head[N],f[N];
vector < pair<int,int> > v;
struct hhh
{
int to,next,w;
}edge[2*N];
int read()
{
int ans=0,fu=1;
char j=getchar();
for (;j<'0' || j>'9';j=getchar()) if (j=='-') fu=-1;
for (;j>='0' && j<='9';j=getchar()) ans*=10,ans+=j-'0';
return ans*fu;
}
void add(int u,int v,int w)
{
edge[cnt].to=v;edge[cnt].next=head[u];edge[cnt].w=w;head[u]=cnt++;
edge[cnt].to=u;edge[cnt].next=head[v];edge[cnt].w=w;head[v]=cnt++;
}
void getroot(int x,int fa)
{
sze[x]=1;
son[x]=0;
for (int i=head[x];i;i=edge[i].next)
if (!vis[edge[i].to] && edge[i].to!=fa)
{
getroot(edge[i].to,x);
son[x]=max(son[x],sze[edge[i].to]);
sze[x]+=sze[edge[i].to];
}
son[x]=max(son[x],sum-sze[x]);
if (son[x]<son[rt]) rt=x;
}
void getdis(int x,int fa)
{
deep_mx=max(deep_mx,deep[x]);
for (int i=head[x];i;i=edge[i].next)
if (!vis[edge[i].to] && edfe[i].to!=fa)
{
deep[edge[i].to]=deep[x]+color[edge[i].to];
dis[edge[i].to]=dis[x]+edge[i].w;
getdis(edge[i].to,x);
}
}
void getmx(int x,int fa)
{
tmp[deep[x]]=max(tmp[deep[x]],dis[x]);
for (int i=head[x];i;i=edge[i].to)
if (!vis[edge[i].to] && edge[i].to!=fa)
getmx(edge[i].to,x);
}
void solve(int x)
{
vis[x]=1;
v.clear();
for (int i=head[x];i;i=edge[i].next)
if (!vis[edge[i].to])
{
deep_mx=0;
deep[edge[i].to]=color[edge[i].to];
dis[edge[i].to]=edge[i].ww;
getdis(edge[i].to,x);
v.push_back(make_pair(deep_mx,edge[i].to));
}
sort(v.begin(),v.end());
int s=v.size();
for (int i=0;i<s;i++)
{
getmx(st[i].second,x);
int now=0;
if (i!=0)
for (int j=v[i].first;j>=0;j--)
{
while (now+j<K && now<st[i-1].first)
now++,mx[now]=max(mx[now],mx[now-1]);
if (now+j<=K) ans=max(mx[now]+tmp[j]);
}
if (i!=s-1)
for (int j=0;j<=v[i].first;j++)
mx[j]=max(mx[j],tmp[j]),tmp[j]=0;
else
for (int j=0;j<=v[i].first;j++)
{
if (j<=K) ans=max(ans,max(tmp[j],mx[j]));
tmp[j]=mx[j]=0;
}
}
}
int main()
{
n=read();
K=read();
m=read();
for (int i=1;i<=m;i++)
{
int x=read();
color[x]=1;
}
for (int i=1,u,v,w;i<n;i++)
{
u=read();v=read();w=read();
add(u,v,w);
}
sum=n;
f[0]=n;
getroot(1,0);
solve(rt);
printf("%d",ans);
return 0;
}