zoukankan      html  css  js  c++  java
  • 暴力写挂

    题目描述

    题解

    考虑把式子化一下,因为只有一个式子跟第二棵树有关,所以我们可以考虑把前面的式子化成跟 $ ext{lca}$ 没有关系,即 $frac{1}{2}(dp_u+dp_v+dis(u,v))$ 。因此我们可以利用边分治,每次把两边的点黑白染色,构成虚树,然后做 $ ext{dp}$ 即可。这里要注意 $ ext{lca}$ 要 $O(1)$ 求,虚树构成过程中不能排序,故我们可以一开始就按照第二棵树的dfs排序好,之后分治下去即可。效率: $O(nlogn)$ 。

    代码

    #include <bits/stdc++.h>
    #define LL long long
    using namespace std;
    const int N=4e5+5,N2=N<<1,N4=N<<2;
    int n,m,t=1,fa[22][N2],Lg[N2],d[N],e[N],o,rt,su,hd[N2],sz[N2];
    int V[N4],W[N4],nx[N4],b[N],id[N],col[N],tp,S[N],h[2][N],c;
    LL dp[N],Dp[N],sm[N],f[2][N],ans=-2e18;
    bool vis[N2];vector<int>X[N],Y[N];
    void Add(int u,int v,int w){
        X[u].push_back(v);Y[u].push_back(w);
    }
    void add(int u,int v,int w){
        nx[++t]=hd[u];V[hd[u]=t]=v;W[t]=w;
    }
    void add(int u,int v){X[u].push_back(v);}
    void rebuild(int u,int fr){
        int x=0,z=X[u].size();
        for (int v,w,i=0;i<z;i++){
            v=X[u][i];w=Y[u][i];
            if (v==fr) continue;dp[v]=dp[u]+w;
            if (!x) add(u,v,w),add(v,u,w),x=u;
            else m++,add(x,m,0),add(m,x,0),
                add(m,v,w),add(v,m,w),x=m;
            rebuild(v,u);
        }
    }
    void dfs(int u,int fr){
        int z=X[u].size();
        fa[0][e[u]=++c]=u;b[id[u]=++t]=u;
        for (int v,w,i=0;i<z;i++){
            v=X[u][i];w=Y[u][i];
            if (v==fr) continue;
            Dp[v]=Dp[u]+w;d[v]=d[u]+1;
            dfs(v,u);fa[0][++c]=u;
        }
    }
    int Min(int u,int v){return d[u]<d[v]?u:v;}
    int qry(int l,int r){
        l=e[l];r=e[r];
        if (l>r) swap(l,r);int i=Lg[r-l+1];
        return Min(fa[i][l],fa[i][r-(1<<i)+1]);
    }
    void Sz(int u,int fr){
        sz[u]=1;
        for (int v,i=hd[u];i;i=nx[i])
            if (!vis[i>>1] && (v=V[i])!=fr)
                Sz(v,u),sz[u]+=sz[v];
    }
    void Rt(int u,int fr){
        for (int v,w,i=hd[u];i;i=nx[i])
            if (!vis[i>>1] && (v=V[i])!=fr){
                w=max(sz[v],o-sz[v]);
                if (w<su) rt=i,su=w;Rt(v,u);
            }
    }
    void find(int u,int fr,LL w,int cl){
        if (u<=n) col[u]=cl,sm[u]=w;
        for (int v,i=hd[u];i;i=nx[i])
            if (!vis[i>>1] && (v=V[i])!=fr)
                find(V[i],u,w+W[i],cl);
    }
    void ins(int u){
        if (tp<1){S[++tp]=u;return;}
        int x=qry(S[tp],u);
        if (x==S[tp]){S[++tp]=u;return;}
        while(tp>1 && id[S[tp-1]]>=id[x])
            add(S[tp-1],S[tp]),tp--;
        if (x!=S[tp]) add(x,S[tp]),S[tp]=x;
        S[++tp]=u;
    }
    void get(int u){
        int z=X[u].size();
        f[0][u]=f[1][u]=-2e18;
        if (~col[u]) f[col[u]][u]=dp[u]+sm[u];
        for (int v,i=0;i<z;i++){
            v=X[u][i];get(v);
            for (int j=0;j<2;j++)
                ans=max(ans,f[j][u]+f[!j][v]-2ll*Dp[u]);
            for (int j=0;j<2;j++)
                f[j][u]=max(f[j][u],f[j][v]);
        }
        X[u].clear();
    }
    void solve(int u,int l,int r){
        Sz(u,0);o=sz[u];rt=0;su=1e9;
        Rt(u,0);if (!rt) return;
        int x=V[rt],y=V[rt^1];vis[rt>>1]=1;
        find(x,y,0,0);find(y,x,W[rt],1);
        if (b[l]!=1) ins(1);
        for (int i=l;i<=r;i++) ins(b[i]);
        while(tp>1) add(S[tp-1],S[tp]),tp--;
        tp=0;get(1);int v[2]={0};
        for (int w,i=l;i<=r;i++)
            w=col[b[i]],h[w][++v[w]]=b[i],col[b[i]]=-1;
        for (int i=0;i<v[0];i++) b[i+l]=h[0][i+1];
        for (int i=0;i<v[1];i++) b[r-i]=h[1][v[1]-i];
        solve(x,l,l+v[0]-1);solve(y,r-v[1]+1,r);
    }
    int main(){
        cin>>n;m=n;
        for (int i=1,x,y,z;i<n;i++)
            scanf("%d%d%d",&x,&y,&z),
            Add(x,y,z),Add(y,x,z);rebuild(1,0);
        for (int i=1;i<=n;i++)
            X[i].clear(),Y[i].clear(),col[i]=-1;t=0;
        for (int i=1,x,y,z;i<n;i++)
            scanf("%d%d%d",&x,&y,&z),
            Add(x,y,z),Add(y,x,z);dfs(1,0);
        for (int i=2;i<=c;i++) Lg[i]=Lg[i>>1]+1;
        for (int i=c;i;i--)
            for (int j=1;i+(1<<j)<=c+1;j++)
                fa[j][i]=Min(fa[j-1][i],fa[j-1][i+(1<<(j-1))]);
        for (int i=1;i<=n;i++) X[i].clear();
        solve(1,1,n);ans>>=1;
        for (int i=1;i<=n;i++)
            ans=max(ans,dp[i]-Dp[i]);
        cout<<ans<<endl;return 0;
    }
  • 相关阅读:
    第三章例3-3
    第三章例3-2
    第二章例2-11
    第二章例2-10
    第二章例2-9
    204
    205
    202
    203
    201
  • 原文地址:https://www.cnblogs.com/xjqxjq/p/12368634.html
Copyright © 2011-2022 走看看