可以发现,每个特殊点可以贡献的部分在树上是一条链。
设三元组(v,x,y)表示路径长度,需要更新的端点,与当前点的lca为y。
对于每个节点x,通过两遍树形DP可以求出:
d[x]:x到x子树内的某个特殊点的最优解。
u[x]:x到x子树外的某个特殊点的最优解。
pre[x]:x以及x之前的兄弟的d[]的最优解。
suf[x]:x以及x之后的兄弟的d[]的最优解。
然后在树上打标记,最后dfs一遍统计答案即可。
时间复杂度$O(n)$。
#include<cstdio> #define N 100010 int n,m,i,x,y,z,vip[N],ans1,ans2; int g[N],v[N<<1],w[N<<1],nxt[N<<1],ed,dis[N],q[N],t,tag[N],cnt[N]; struct P{ int v,x,y; P(){v=-1;} P(int _v,int _x,int _y){v=_v,x=_x,y=_y;} }d[N],u[N],pre[N],suf[N],fin[N]; inline void read(int&a){char c;while(!(((c=getchar())>='0')&&(c<='9')));a=c-'0';while(((c=getchar())>='0')&&(c<='9'))(a*=10)+=c-'0';} inline void add(int x,int y,int z){v[++ed]=y;w[ed]=z;nxt[ed]=g[x];g[x]=ed;} void dfs1(int x,int f){ if(vip[x])d[x]=P(0,x,x); for(int i=g[x];i;i=nxt[i])if(v[i]!=f){ int y=v[i];dis[y]=w[i]; dfs1(y,x); if(d[y].v<0)continue; if(d[y].v+w[i]>d[x].v)d[x]=d[y],d[x].v+=w[i]; else if(d[y].v+w[i]==d[x].v)d[x].x=x; } d[x].y=x; fin[x]=d[x]; } void dfs2(int x,int f){ if(vip[x]&&u[x].v<0)u[x]=P(0,x,x); t=0; for(int i=g[x];i;i=nxt[i])if(v[i]!=f)q[++t]=v[i]; for(int i=1;i<=t;i++){ int y=q[i]; pre[i]=pre[i-1]; if(d[y].v<0)continue; if(d[y].v+dis[y]>pre[i].v)pre[i]=d[y],pre[i].v+=dis[y]; else if(d[y].v+dis[y]>pre[i].v)pre[i].x=x; } suf[t+1]=P(); for(int i=t;i;i--){ int y=q[i]; suf[i]=suf[i+1]; if(d[y].v<0)continue; if(d[y].v+dis[y]>suf[i].v)suf[i]=d[y],suf[i].v+=dis[y]; else if(d[y].v+dis[y]>suf[i].v)suf[i].x=x; } for(int i=1;i<=t;i++){ int y=q[i]; P B=pre[i-1]; if(B.v<suf[i+1].v)B=suf[i+1]; else if(B.v==suf[i+1].v)B.x=x; B.y=x; if(B.v<u[x].v)B=u[x]; else if(B.v==u[x].v)B.x=x; if(~B.v)B.v+=dis[y]; u[y]=B; if(!vip[y])continue; if(B.v>fin[y].v)fin[y]=B; else if(B.v==fin[y].v)fin[y].v=-1; } for(int i=g[x];i;i=nxt[i])if(v[i]!=f)dfs2(v[i],x); } inline void modify(int x,int y,int z){tag[x]++,tag[y]++,tag[z]-=2,cnt[z]++;} void dfs3(int x,int f){ for(int i=g[x];i;i=nxt[i])if(v[i]!=f)dfs3(v[i],x),tag[x]+=tag[v[i]]; cnt[x]+=tag[x]; } int main(){ read(n),read(m); for(i=1;i<=m;i++)read(x),vip[x]=1; for(i=1;i<n;i++)read(x),read(y),read(z),add(x,y,z),add(y,x,z); dfs1(1,0),dfs2(1,0); for(i=1;i<=n;i++)if(vip[i])if(~fin[i].v)modify(i,fin[i].x,fin[i].y); dfs3(1,0); for(i=1;i<=n;i++)if(!vip[i]){ if(cnt[i]>ans1)ans1=cnt[i],ans2=1; else if(cnt[i]==ans1)ans2++; } return printf("%d %d",ans1,ans2),0; }