突然发现上一题的写法有点问题……
似乎应该枚举那条多出来的边（即使图成环的那条边）的两个端点分别处于什么状态（选或不选）。
// Maximum-weight independent set on a functional graph: node i has weight
// w[i] and exactly one edge i -> a[i].  The program prints the largest total
// weight of a node set in which no edge has both endpoints selected.
//
// Input (stdin): n, followed by n pairs "w[i] a[i]" (non-negative integers).
//
// Method:
//   1. Union-find keeps i -> a[i] as a tree edge (a[i] becomes i's parent)
//      unless i and a[i] are already connected; such a cycle-closing edge is
//      stored in e[] instead.  main assumes exactly one such edge per
//      connected component, so each component is a tree rooted at e.x.
//   2. A leaves-first topological order (rank = height above the leaves)
//      guarantees every child is processed before its parent.
//   3. For each component a tree DP runs twice with the endpoint e.y forced:
//        - e.y not selected -> root e.x may or may not be selected;
//        - e.y selected     -> root e.x must not be selected.
//      The best of these cases is added to the answer.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#define maxv 1000050
#define inf 0x7f7f7f7f7f7f7f7fLL
using namespace std;

// a[v]       : target of v's single outgoing edge (v's tree parent if kept).
// dy[v]      : number of v's tree children not yet processed (in-degree).
// dp[v][1/2] : best weight inside v's subtree with v not selected / selected
//              (column 0 is unused).
// father[]   : union-find parent array.
long long n, a[maxv], dy[maxv], dp[maxv][3], cnt = 0, father[maxv], ans = 0, w[maxv];

// Per-node record; after topu_sort() the array is grouped by component and
// ordered by rank inside each group.
struct pnt
{
    long long id, father, rank;  // node index, component representative, height above leaves
} p[maxv];

// Cycle-closing edge x -> y (y == a[x]) rejected by union-find, tagged with
// the representative of the component it belongs to.
struct edge
{
    long long x, y, father;
} e[maxv];

queue<long long> q;

// Reads the next non-negative integer from stdin, skipping any non-digit
// characters in front of it.  Returns 0 if EOF is reached before a digit.
long long read()
{
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    long long data = 0;
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    while (ch >= '0' && ch <= '9')
    {
        data = data * 10 + ch - '0';
        ch = getchar();
    }
    return data;
}

// Orders nodes by component, then leaves-first inside a component.
bool cmp1(const pnt &x, const pnt &y)
{
    if (x.father != y.father) return x.father < y.father;
    return x.rank < y.rank;
}

// Orders rejected edges by component, matching the component order of cmp1.
bool cmp2(const edge &x, const edge &y) { return x.father < y.father; }

// Union-find lookup with path compression.  Iterative on purpose: unions are
// not balanced, so chains can be O(n) long and a recursive lookup could
// overflow the stack.
long long getfather(long long x)
{
    long long root = x;
    while (father[root] != root) root = father[root];
    while (father[x] != root)
    {
        long long next = father[x];
        father[x] = root;
        x = next;
    }
    return root;
}

// Assigns each node its component representative and its rank (height above
// the leaves), then sorts p[] by (component, rank) and e[] by component.
void topu_sort()
{
    for (long long i = 1; i <= n; i++)
    {
        if (!dy[i]) q.push(i);
        p[i].father = getfather(i);
    }
    while (!q.empty())
    {
        long long head = q.front();
        q.pop();
        long long up = a[head];
        // For a component root, a[head] is e.y, which was already processed;
        // its counter just goes negative and is never pushed again.
        dy[up]--;
        if (!dy[up])
        {
            p[up].rank = p[head].rank + 1;
            q.push(up);
        }
    }
    sort(p + 1, p + n + 1, cmp1);
    sort(e + 1, e + cnt + 1, cmp2);
}

// Tree DP over the component occupying p[l..r] (children before parents).
// x is the node whose state is forced: type == 2 forces x selected, any
// other type forces x not selected.
void tree_dp(long long l, long long r, long long x, long long type)
{
    for (long long i = l; i <= r; i++) dp[p[i].id][1] = dp[p[i].id][2] = 0;
    for (long long i = l; i <= r; i++)
    {
        long long v = p[i].id;
        if (v != x)
            dp[v][2] += w[v];
        else if (type == 2)
        {
            dp[v][1] = -inf;  // forced selected
            dp[v][2] += w[v];
        }
        else
            dp[v][2] = -inf;  // forced not selected
        // Push v's values into its parent.  For the component root, a[v] is
        // e.y, already finalised, so this write does not affect the result.
        if (v != a[v])
        {
            dp[a[v]][1] += max(dp[v][1], dp[v][2]);
            dp[a[v]][2] += dp[v][1];
        }
    }
}

int main()
{
    n = read();
    for (long long i = 1; i <= n; i++)
    {
        father[i] = i;
        p[i].id = i;
    }
    for (long long i = 1; i <= n; i++)
    {
        w[i] = read();
        a[i] = read();
        long long f1 = getfather(i), f2 = getfather(a[i]);
        if (f1 == f2)
        {
            // i -> a[i] would close a cycle: remember it instead of joining.
            e[++cnt].x = i;
            e[cnt].y = a[i];
        }
        else
        {
            dy[a[i]]++;
            father[f1] = f2;
        }
    }
    for (long long i = 1; i <= cnt; i++) e[i].father = getfather(e[i].x);
    topu_sort();

    long long l = 1, ret = 0;
    while (l <= n)
    {
        // p[l..r] is one component; e[ret] is its rejected edge (same order).
        long long r = l;
        while (r < n && p[r + 1].father == p[l].father) r++;
        ret++;
        long long root = e[ret].x, other = e[ret].y;
        long long mx = 0;
        // e.y not selected: the root is unconstrained.
        tree_dp(l, r, other, 1);
        mx = max(mx, max(dp[root][1], dp[root][2]));
        // e.y selected: the root must stay unselected.
        tree_dp(l, r, other, 2);
        mx = max(mx, dp[root][1]);
        ans += mx;
        l = r + 1;
    }
    printf("%lld ", ans);
    return 0;
}