zoukankan      html  css  js  c++  java
  • 【HNOI】 lct tree-dp

      【题目描述】给定2-3颗树,每个边的边权为1,解决以下独立的问题。

        现在通过连接若干遍使得图为连通图,并且Σdis(x,y)最大,x,y只算一次。

        每个点为黑点或者白点,现在需要删除一些边,使得图中的黑点度数为奇数,白点为偶数,要求删除的边最多。

      【数据范围】 100% n<=10^5

      首先我们来解决第一问,因为每加一条边就可能使得若干点到其他点的距离变小,那么我们需要加尽量少的边来使得图连通。

      设dis_[x]为x在x所在子树中,x到其他所有点的距离,这个我们可以通过设dis[x]表示x到x子树中所有点的距离和来由父节点转移得到。

      那么答案可以分为两部分,分别为树中的点对距离和跨树的点对距离,前一个问题比较容易,可以通过dis_[x]或者计算每条边被经过的次数来求出。

      那么对于两颗树的情况,我们就需要连接这两棵树中dis值最大的两个点,假设为x,y。这样答案就是      dis[x]*size[tree_y]+dis[y]*size[tree_x]+size[tree_x]*size[tree_y],这个由连接的那条边的被经过次数可以得出。

      那么现在考虑三棵树的情况,我们需要枚举中间的树,这样左右两棵树肯定连接dis最大的点,中间的连接的则不确定,我们可以列出来整个答案的表达式,设左面的树和中间的树通过x,y点连通,中间的点和右面的树通过u,v点连接,设三棵树的size为size[1],size[2],size[3],y与u点的距离为d[y][u],那么答案就是size[1]*dis_[y]+size[2]*dis_[x]+size[1]*size[2]+size[3]*dis_[u]+size[2]*dis_[v]+size[1]*dis_[v]+size[3]*dis_[u]+size[1]*size[3]*(d[y][u]+2)

      我们可以发现,这个中与y,u点有关的式子可以写成a*dis_[y]+b*dis_[u]+c*d[u][v]的形式,其中a,b,c为常数,那么对于这个我们就可以用tree-dp搞出来,记录点x的子树中dis_[p]+(d[x][p]+1)*c的最大值,然后不断的更新答案就可以了。

      第二问比较简单,我们可以贪心的来想,对于一棵树,我们从叶子节点开始,因为叶子节点的度数为1,那么我们只需要判断叶子节点的颜色,就可以判断这个点和其父节点的边是否可以删掉。

      反思:开始写tree-dp维护中间树的值的时候没有考虑到一些特殊情况,比如连接的y,u点其中一点是另一点的祖先,还有开始觉得如果中间的树选择两个点肯定不能是同一点,所以边界就处理的不是特别好,但是可能会有某些点单独构成树,这样的话就必须连接同一个点。第二问还是比较容易写的。

    //By BLADEVIL
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #define maxn 100010
    #define LL long long
    
    using namespace std;
    
    LL n,m,l;
    LL a[maxn],pre[maxn<<1],other[maxn<<1],last[maxn],col[maxn],rot[4],num[maxn<<1],flag[maxn];
    LL dis_[maxn],dis[maxn],size[maxn],ans[maxn],ANS[4],max_a[maxn],max_b[maxn],cnt[maxn];
    
    void connect(LL x,LL y,LL z) {
        pre[++l]=last[x];
        last[x]=l;
        other[l]=y;
        num[l]=z;
    }
    
    void paint(LL x,LL fa,LL c) {
        col[x]=c;
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            paint(other[p],x,c);
        }
    }
    
    void make_dis(LL x,LL fa) {
        dis[x]=0; size[x]=1;
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            make_dis(other[p],x);
            dis[x]+=dis[other[p]]+size[other[p]];
            size[x]+=size[other[p]];
        }
    }
    
    void make_dis_(LL x,LL fa,LL s) {
        if (fa!=-1) dis_[x]=dis_[fa]-size[x]-dis[x]+s-size[x]+dis[x]; else dis_[x]=dis[x];
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            make_dis_(other[p],x,s);
        }
    }
    
    void calc(LL x,LL fa,LL s) {
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            ANS[col[x]]+=size[other[p]]*(s-size[other[p]]);
            calc(other[p],x,s);
        }
    }
    
    void dp(LL x,LL fa,LL a,LL b,LL c,LL &Ans) {
        max_a[x]=dis_[x]*a+c; max_b[x]=dis_[x]*b+c;
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            dp(other[p],x,a,b,c,Ans);
            max_a[x]=max(max_a[x],max_a[other[p]]+c);
            max_b[x]=max(max_b[x],max_b[other[p]]+c);
        }
        LL aa=0,bb=0;
        //printf("%d %d
    ",max_a[x],max_b[x]);
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            Ans=max(Ans,max_a[x]+c+dis_[x]*b);
            Ans=max(Ans,max_b[x]+c+dis_[x]*a);
        }
        //printf("%d %d
    ",x,Ans);
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            if (max_a[other[p]]>max_a[aa]) aa=other[p];
            if (max_b[other[p]]>max_b[bb]) bb=other[p];
        }
        //printf("%d %d
    ",x,Ans);
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            if (other[p]!=aa) Ans=max(Ans,max_a[aa]+max_b[other[p]]+(c<<1));
            //printf("%d %d %d %d
    ",Ans,max_a[aa],max_b[other[p]],c<<1);
            if (other[p]!=bb) Ans=max(Ans,max_b[bb]+max_a[other[p]]+(c<<1));
        }
        //printf("%d %d
    ",aa,max_a[aa]);
        //printf("%d %d
    ",x,Ans);
    }
    
    LL work(LL le,LL x,LL ri) {
        LL a=size[le],b=size[ri],c=size[le]*size[ri],ans=0;
        LL cur[4]; cur[1]=cur[2]=cur[3]=0;
        for (LL i=1;i<=n;i++) cur[col[i]]=max(cur[col[i]],dis_[i]);
        //printf("fuck %d %d
    ",col[le],col[ri]);
        ans=cur[col[le]]*size[x]+a*size[x]+cur[col[ri]]*size[x]+size[x]*b+a*cur[col[ri]]+b*cur[col[le]];
        //printf("fuck
    ");
        memset(max_a,0,sizeof max_a);
        memset(max_b,0,sizeof max_b);
        LL Ans=-1;
        dp(x,-1,a,b,c,Ans);
        Ans=max(Ans,c<<1);
        //printf("%d %d
    ",ans,Ans);
        ans+=Ans;
        //printf("%d
    ",ans);
        return ans;
    }
    
    void Work(LL x,LL fa) {
        for (LL p=last[x];p;p=pre[p]) {
            if (other[p]==fa) continue;
            Work(other[p],x);
        }
        //printf("%d %d %d
    ",x,a[x],cnt[x]);
        if (a[x]^cnt[x]) {
            for (LL p=last[x];p;p=pre[p]) 
                if (other[p]==fa) flag[num[p]]=1;
            cnt[fa]^=1;
        };
    }
    
    int main() {
        freopen("lct.in","r",stdin); freopen("lct.out","w",stdout);
        scanf("%lld%lld
    ",&n,&m);
        char c;
        for (LL i=1;i<=n;i++) scanf("%c",&c),a[i]=(c=='B')?1:0;
        for (LL i=1;i<=m;i++) {
            LL x,y;
            scanf("%lld%lld",&x,&y);
            connect(x,y,i); connect(y,x,i);
        }
        LL sum=0;
        for (LL i=1;i<=n;i++) if (!col[i]) paint(i,-1,++sum),rot[sum]=i;
        for (LL i=1;i<=3;i++) if (rot[i]) make_dis(rot[i],-1),make_dis_(rot[i],-1,size[rot[i]]);
        for (LL i=1;i<=3;i++) if (rot[i]) calc(rot[i],-1,size[rot[i]]);
        //for (LL i=1;i<=n;i++) printf("%d ",col[i]); printf("
    ");
        //printf("%d %d %d
    ",rot[1],rot[2],rot[3]);
        //for (LL i=1;i<=n;i++) printf("%d %d %d %d
    ",i,dis[i],dis_[i],size[i]);
        //for (LL i=1;i<=3;i++) printf("%d ",ANS[i]); printf("
    ");
        if (sum==2) {
            LL cur[3];
            cur[1]=cur[2]=0;
            for (LL i=1;i<=n;i++) cur[col[i]]=max(cur[col[i]],dis_[i]);
            LL Ans=ANS[1]+ANS[2]+cur[1]*size[rot[2]]+cur[2]*size[rot[1]]+size[rot[1]]*size[rot[2]];
            printf("%lld
    ",Ans);
            //printf("%d %d
    ",cur[1],cur[2]);
            //printf("%d %d
    ",size[rot[1]],size[rot[2]]);
            //printf("%d %d
    ",ANS[1],ANS[2]);
        } else {
            LL Ans=0;
            Ans=max(Ans,work(rot[2],rot[1],rot[3]));
            Ans=max(Ans,work(rot[1],rot[2],rot[3]));
            Ans=max(Ans,work(rot[1],rot[3],rot[2]));
            //printf("%d
    ",Ans);
            Ans+=ANS[1]+ANS[2]+ANS[3];
            printf("%lld
    ",Ans);
        }
        for (LL i=1;i<=n;i++) 
            for (LL p=last[i];p;p=pre[p]) cnt[other[p]]++,cnt[i]++;
        //for (LL i=1;i<=n;i++) printf("%d
    ",cnt[i]);
        //for (LL i=1;i<=n;i++) printf("%d
    ",a[i]);
        for (LL i=1;i<=n;i++) cnt[i]/=2,cnt[i]%=2;
        for (LL i=1;i<=3;i++) Work(rot[i],-1);
        LL ans_=0;
        for (LL i=1;i<=m;i++) if (!flag[i]) ans_++;
        printf("%lld
    ",ans_);
        for (LL i=1;i<=m;i++) if (!flag[i]) printf("%lld ",i); printf("
    ");
        fclose(stdin); fclose(stdout);
        return 0;
    }
  • 相关阅读:
    javascript阻止子元素继承父元素事件
    UTC 时间转化为北京时间
    uniapp中引入less文件
    HDU 1002 A + B Problem II(大数据)
    FatMouse's Speed(dp)
    Monkey and Banana(dp)
    Piggy-Bank(dp,背包)
    Longest Ordered Subsequence(最长上升子序列,dp)
    我的第一篇博客
    redis优化方案
  • 原文地址:https://www.cnblogs.com/BLADEVIL/p/3625012.html
Copyright © 2011-2022 走看看