设f[i]为由i开始遍历完子树内所要求的点的最短时间,g[i]为由i开始遍历完子树内所要求的点最后回到i的最短时间。则g[i]=Σ(g[j]+2),f[i]=min{g[i]-g[j]+f[j]-1}。
然后由父亲答案还原。因为上面的dp用到了max似乎不太好搞,于是记录一下最大值是用了哪棵子树以及次大值就行了。
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<algorithm> using namespace std; int read() { int x=0,f=1;char c=getchar(); while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();} while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar(); return x*f; } #define N 500010 #define ll long long int n,m,a[N],id[N],p[N],size[N],t=0; ll f[N],f2[N],g[N]; bool flag[N]; struct data{int to,nxt,len; }edge[N<<1]; void addedge(int x,int y,int z){t++;edge[t].to=y,edge[t].nxt=p[x],edge[t].len=z,p[x]=t;} inline ll noback(int x,int y,int z){return g[x]-g[y]+f[y]-z;} void dfs(int k,int from) { size[k]=flag[k]; for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=from) { dfs(edge[i].to,k); size[k]+=size[edge[i].to]; if (size[edge[i].to]) g[k]+=g[edge[i].to]+(edge[i].len<<1); } f[k]=f2[k]=g[k]; for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=from&&size[edge[i].to]) { ll x=noback(k,edge[i].to,edge[i].len); if (x<f[k]) f2[k]=f[k],f[k]=x,id[k]=edge[i].to; else if (x<f2[k]) f2[k]=x; } } void getans(int k,int from) { for (int i=p[k];i;i=edge[i].nxt) if (edge[i].to!=from) { if (size[edge[i].to]) { if (size[edge[i].to]<m) { ll x=f[k],y=g[k]; g[k]-=g[edge[i].to]+(edge[i].len<<1); if (id[k]==edge[i].to) f[k]=f2[k]-(g[edge[i].to]+(edge[i].len<<1)); else f[k]=f[k]-(g[edge[i].to]+(edge[i].len<<1)); g[edge[i].to]=y; if (noback(edge[i].to,k,edge[i].len)<f[edge[i].to]+g[k]+(edge[i].len<<1)) f2[edge[i].to]=f[edge[i].to]+g[k]+(edge[i].len<<1),id[edge[i].to]=k,f[edge[i].to]=noback(edge[i].to,k,edge[i].len); else f[edge[i].to]+=g[k]+(edge[i].len<<1),f2[edge[i].to]=min(f2[edge[i].to]+g[k]+(edge[i].len<<1),noback(edge[i].to,k,edge[i].len)); f[k]=x,g[k]=y; } } else g[edge[i].to]=g[k]+(edge[i].len<<1),f[edge[i].to]=f[k]+edge[i].len; getans(edge[i].to,k); } } int main() { #ifndef ONLINE_JUDGE freopen("bzoj3743.in","r",stdin); freopen("bzoj3743.out","w",stdout); const char LL[]="%I64d "; #else const char LL[]="%lld "; #endif n=read(),m=read(); for (int i=1;i<n;i++) { int x=read(),y=read(),z=read(); addedge(x,y,z),addedge(y,x,z); } for (int i=1;i<=m;i++) flag[read()]=1; dfs(1,1); getans(1,1); for (int i=1;i<=n;i++) printf(LL,f[i]); return 0; }