基环树DP。
题目大意:给定一张$n$个点$n$条边的无向图,求所有连通块的直径长度之和。
----------------------
考虑到边数与点数相等,即一个连通块内不可能有两个环。考虑基环树DP。现在先把连通块内的环看成一个点,对于直径,有如下两种情况:
1.直径在环上点的子树中
2.直径横跨环,端点分别在两个子树中
对于情况1,我们只需进行树形DP即可。对于情况2,因为是在环上,所以环上两点之间的简单路径有两条,所以我们不妨破换成链。设以$i$为根的子树中最大距离为$d_i$,对于一个连通块,我们显然想要$d_i+d_j+s_{i,j}$的值最大($s$指环上$i$到$j$的距离)。$s_{i,j}$可以通过前缀和相减得到,这样就转化成为$d_i+s_i+d_j-s_j$最大。于是可以用单调队列优化。
细节有点小多。时间复杂度$O(n)$。
代码:
#include<cstdio> #include<iostream> #include<queue> #define int long long using namespace std; const int maxn=1000005; int n,tot,cnt,st,ans,ans2,ans3,anss; int s[maxn],v[maxn],v2[maxn],dfn[maxn],d[maxn],dp[maxn*2]; int head[maxn]; struct node { int next,to,dis; }edge[maxn*2]; inline int read() { int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();} while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();} return x*f; } inline void add(int from,int to,int dis) { edge[++cnt]=(node){head[from],to,dis}; head[from]=cnt; } inline bool dfs(int now,int pre) { if (v[now]==1) { v[now]=2,v2[now]=1;dfn[++tot]=now; return 1; } v[now]=1; for (int i=head[now];i;i=edge[i].next) { if (i!=((pre-1)^1)+1&&dfs(edge[i].to,i)) { if (v[now]!=2) { dfn[++tot]=now,v2[now]=1; s[tot]=s[tot-1]+edge[i].dis; } else { s[st-1]=s[st]-edge[i].dis; return 0; } return 1; } } return 0; } inline void dfs_dp(int now) { v2[now]=1; for (int i=head[now];i;i=edge[i].next) { int to=edge[i].to; if (v2[to]) continue; dfs_dp(to); ans=max(ans,d[now]+d[to]+edge[i].dis); d[now]=max(d[now],d[to]+edge[i].dis); } } inline int solve(int root) { st=tot+1,ans2=0,ans3=0; dfs(root,0); for (int i=st;i<=tot;i++) { ans=0; dfs_dp(dfn[i]); ans2=max(ans2,ans); dp[i+tot-st+1]=dp[i]=d[dfn[i]]; s[i+tot-st+1]=s[i+tot-st]+s[i]-s[i-1]; } deque<int> q; for (int i=st;i<=2*tot-st+1;i++) { while(q.size()&&q.front()<=i-tot+st-1) q.pop_front(); if (q.size()) ans3=max(ans3,dp[i]+dp[q.front()]+s[i]-s[q.front()]); while(q.size()&&dp[q.back()]-s[q.back()]<=dp[i]-s[i]) q.pop_back(); q.push_back(i); } return max(ans2,ans3); } signed main() { n=read(); for (int i=1;i<=n;i++) { int x=read(),y=read(); add(i,x,y);add(x,i,y); } for (int i=1;i<=n;i++) if (!v2[i]) anss+=solve(i); printf("%lld",anss); return 0; }