这道题我们加一条路可以减少的代价为这条路两端点到lca的路径的长度,相当于一条链,那么如果加了两条链的话,这两条链重复的部分还是要走两遍,反而对答案没有了贡献(其实这个可以由任意两条链都可以看成两条不重叠的链来证明),那么这道题k=2的时候就转化为了求出树上两条链,使得两条链不重叠的长度最大,那么答案就是(n-1)<<1-SumLen+2.当k=1的时候我们直接求出来树的最长链然后减去就好了,这个在此不再赘述。
对于树上两链不重复部分最大我们是可以tree_dp的,设w[i][0..4]来表示当前以i为根的子树中选取了0/1/2条链的最大值,同时我们保留了一个3,4来记录以i为一端点的最长链,同时选取了0/1条最长链的最大值,这样直接转移就好了。
我写的是另外一种方法,先找出最长链,然后将最长链上的边长设为-1,然后再找一次最长链,这样求出来的就是答案。
反思:开始没意识到第二次最长链不能用两边bfs,所以果断的写了bfs,后来才发现的,又临时加了一个tree_dp,因为加的路必须选,所以我们要将每个点的最长和次长链设为-inf,叶子节点的为0,然后用非叶子节点更新答案,然后竟然1A,真是感动= =。
/************************************************************** Problem: 1912 User: BLADEVIL Language: C++ Result: Accepted Time:1268 ms Memory:5884 kb ****************************************************************/ //By BLADEVIL #include <cstdio> #include <cstring> #include <algorithm> #define maxn 100010 #define maxm 200020 #define inf (~0U>>1) using namespace std; int n,k,l; int pre[maxm],other[maxm],last[maxn],len[maxm]; int que[maxn],dis[maxn],father[maxn],flag[maxn],max_1[maxn],max_2[maxn]; void connect(int x,int y) { pre[++l]=last[x]; last[x]=l; other[l]=y; len[l]=1; } void bfs(int x) { memset(que,0,sizeof que); memset(dis,0,sizeof dis); memset(father,0,sizeof father); memset(flag,0,sizeof flag); int h=0,t=1; que[1]=x; dis[x]=1; flag[x]=1; while (h<t) { int cur=que[++h]; for (int p=last[cur];p;p=pre[p]) { if (flag[other[p]]) continue; father[other[p]]=p; dis[other[p]]=dis[cur]+len[p]; flag[other[p]]=1; que[++t]=other[p]; } } } int tree_dp() { int ans=-inf; memset(que,0,sizeof que); memset(flag,0,sizeof flag); memset(dis,0,sizeof dis); memset(max_1,-128,sizeof max_1); memset(max_2,-128,sizeof max_2); int h=0,t=1; que[1]=1; flag[1]=1; dis[1]=1; while (h<t) { int cur=que[++h]; for (int p=last[cur];p;p=pre[p]) { if (flag[other[p]]) continue; que[++t]=other[p]; flag[other[p]]=1; dis[other[p]]=dis[cur]+1; } } //for (int i=1;i<=n;i++) printf("%d ",que[i]); printf(" "); for (int i=n;i;i--) { int cur=que[i]; for (int p=last[cur];p;p=pre[p]) { if (dis[other[p]]<dis[cur]) continue; if (max_1[other[p]]+len[p]>max_1[cur]) max_2[cur]=max_1[cur],max_1[cur]=max_1[other[p]]+len[p]; else if (max_1[other[p]]+len[p]>max_2[cur]) max_2[cur]=max_1[other[p]]+len[p]; } if (max_1[cur]<-100000000) max_1[cur]=max_2[cur]=0; else ans=max(ans,max(max_1[cur]+max_2[cur],max_1[cur])); } //for (int i=1;i<=n;i++) printf("|%d %d ",max_1[i],max_2[i]); return ans; } int getmax() { int s=0; for (int i=1;i<=n;i++) if (dis[i]>dis[s]) s=i; return s; } int main() { scanf("%d%d",&n,&k); l=1; for (int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); connect(x,y); connect(y,x); } bfs(1); bfs(getmax()); if (k==1) { printf("%d ",2*n-dis[getmax()]); return 0; } int cur=getmax(),ans=dis[cur]-2; while (father[cur]) len[father[cur]]=len[father[cur]^1]=-1,cur=other[father[cur]^1]; ans+=tree_dp()-1; printf("%d ",2*n-2-ans); return 0; }