zoukankan      html  css  js  c++  java
  • AtCoder Beginner Contest 163 F - path pass i

    传送门:https://atcoder.jp/contests/abc163/tasks/abc163_f

    题目大意:一棵 n 个节点的树,每个节点有一个颜色。对每一种颜色 k,求至少经过一个颜色为 k 的节点的简单路径数量。路径按端点对 (u, v), u ≤ v 计数,包括 u = v 的单点路径,因此路径总数为 $\frac{n(n+1)}{2}$。

      分析:虽然有 O(n) 的做法,但这里还是贴一下虚树的做法。思路是补集计数:对每一种颜色建立虚树,对于每一个标记(即该颜色)的节点,分别统计它每个儿子子树中不含标记节点的连通块大小 $s$,并从总数 $\frac{n(n+1)}{2}$ 中减去 $\frac{s(s+1)}{2}$;若根节点未被标记,还需减去包含根的那个连通块。连通块大小的算法是:用子树大小,减去该子树中最上层标记节点的子树大小之和。

    #include<bits/stdc++.h>
    
    #define all(x) x.begin(),x.end()
    #define fi first
    #define sd second
    #define lson (nd<<1)
    #define rson (nd+nd+1)
    #define PB push_back
    #define mid (l+r>>1)
    #define MP make_pair
    #define SZ(x) (int)x.size()
    
    using namespace std;
    
    typedef long long LL;
    
    typedef vector<int> VI;
    
    typedef pair<int,int> PII;
    
    // Reads the next decimal integer from stdin using getchar().
    // Non-digit characters before the number are skipped; a '-' seen while skipping
    // makes the result negative. If end of input is reached before any digit, returns 0
    // instead of spinning forever.
    inline int read(){
        int res=0, f=1;
        int ch=getchar();   // int, not char: must be able to hold EOF distinctly
        while(ch!=EOF&&(ch<'0'||ch>'9')){
            if(ch=='-')f=-1;
            ch=getchar();
        }
        while(ch>='0'&&ch<='9'){
            res=res*10+(ch-'0');
            ch=getchar();
        }
        return res*f;
    }
    
    const int MAXN = 200'005;
    
    const int MOD = 1000000007;
    
    // a <- (a + b) mod MOD, assuming both operands are already reduced into [0, MOD).
    void addmod(int& a, int b){
        const int total=a+b;
        a=(total>=MOD)?total-MOD:total;
    }
    // (a * b) mod MOD; the product is formed in 64-bit so it cannot overflow int.
    int mulmod(int a, int b){
        const long long product=static_cast<long long>(a)*b;
        return static_cast<int>(product%MOD);
    }
    
    // Lowers a to b when b is strictly smaller (a > b).
    template<class T>
    void chmin(T& a, T b){
        if(!(a>b)) return;
        a=b;
    }
    
    // Raises a to b when b is strictly larger (b > a).
    template<class T>
    void chmax(T& a, T b){
        if(!(b>a)) return;
        a=b;
    }
    
    #define go(e,u) for(int e=head[u];e;e=Next[e])
    int to[MAXN<<1],Next[MAXN<<1],head[MAXN],tol;
    
    // Adds an undirected edge u-v to the forward-star lists (head/Next/to),
    // storing one directed arc in each direction.
    void add_edge(int u,int v){
        const int fwd=++tol;
        to[fwd]=v;
        Next[fwd]=head[u];
        head[u]=fwd;

        const int back=++tol;
        to[back]=u;
        Next[back]=head[v];
        head[v]=back;
    }
    
    #define gov(e,u) for(int e=headv[u];e;e=Nextv[e])
    int tov[MAXN<<1],Nextv[MAXN<<1],headv[MAXN],tolv;
    
    // Adds a single directed arc u -> v to the virtual-tree lists (headv/Nextv/tov).
    void add_edgev(int u,int v){
        const int id=++tolv;
        tov[id]=v;
        Nextv[id]=headv[u];
        headv[u]=id;
    }
    
    int n, col[MAXN];
    
    vector<int> nodes[MAXN];
    
    int dfn[MAXN], R[MAXN], dfncnt;
    int up[MAXN][25], dep[MAXN], st[MAXN], top, sz[MAXN];
    int mark[MAXN];
    
    LL ans;
    
    void dfs(int u, int f){
        sz[u]=1;
        dfn[u]=++dfncnt;
    
        for(int i=0;up[u][i];++i)up[u][i+1]=up[up[u][i]][i];
    
        go(e,u){
            int v=to[e];
            if(v==f)continue;
            up[v][0]=u;
            dep[v]=dep[u]+1;
            dfs(v,u);
            sz[u]+=sz[v];
        }
        R[u]=dfncnt;
    }
    
    int getLCA(int u, int v){
        if(dep[u]<dep[v])swap(u,v);
    
        for(int i=20;i>=0;--i){
            if(dep[up[u][i]]>=dep[v]){
                u=up[u][i];
            }
        }
    
        if(u==v)return u;
    
        for(int i=20;i>=0;--i){
            if(up[u][i]!=up[v][i]){
                u=up[u][i];
                v=up[v][i];
            }
        }
    
        return up[u][0];
    }
    
    // Orders vertices by DFS entry time (dfn), i.e. preorder.
    bool cmp(int x, int y){
        const int tx=dfn[x];
        const int ty=dfn[y];
        return ty>tx;
    }
    
    // Orders (dfn, covered-count) pairs by their dfn key only.
    // Author's note: leaving these pairs unsorted produced Wrong Answer.
    bool cmp2(PII x, PII y){
        return y.first>x.first;
    }
    
    // Post-order DP over the virtual tree of the current color (arcs in headv/Nextv/tov).
    // Returns how many original-tree vertices in u's subtree are "covered", i.e. lie in the
    // subtree of some marked (current-color) vertex: sz[u] if u is marked, otherwise the sum
    // of its virtual children's results.
    // Side effect: subtracts from the global ans, for every connected component of unmarked
    // vertices, the s*(s+1)/2 paths that stay inside a component of s vertices.
    LL dfs1(int u){
        LL s=0;

        // (dfn of virtual child, covered count below it); only collected when u is marked.
        vector<PII> num;
        gov(e,u){
            int v=tov[e];
            LL t=dfs1(v);
            s+=t;
            if(mark[u]) num.PB(MP(dfn[v],t));
        }

        // The matching loop below walks u's real children in adjacency order, which is also
        // the order dfs() assigned their dfn ranges in, so num must be sorted by dfn.
        sort(all(num),cmp2);

        if(mark[u]){
            int idx=0;
            // For each real child v of the marked vertex u, the vertices of v's subtree that are
            // not under a marked descendant form one component of unmarked vertices.
            go(e,u){
                int v=to[e];
                if(v==up[u][0])continue;
                int cc=0;
                // Sum the covered counts of virtual children whose dfn lies in [dfn[v], R[v]].
                while(idx<SZ(num)&&num[idx].fi>=dfn[v]&&num[idx].fi<=R[v]){
                    cc+=num[idx].sd;
                    ++idx;
                }

                ans-=1ll*(sz[v]-cc)*(sz[v]-cc+1)/2;
            }
        }

        // Vertices with no marked ancestor form the component containing vertex 1 (the root);
        // it is only non-empty when the root itself is unmarked.
        if(u==1&&!mark[u]){
            LL num=sz[1]-s;
            ans-=1ll*num*(num+1)/2;
        }

        if(mark[u])return sz[u];
        else return s;
    }
    
    // Reads the tree, then for every color answers: number of paths (u <= v, single vertices
    // included) that contain at least one vertex of that color. Computed as
    // n(n+1)/2 minus the paths lying entirely inside components of other-colored vertices.
    int main(){
        n=read();
        // Bucket vertices by color.
        for(int i=1;i<=n;++i){
            col[i]=read();
            nodes[col[i]].PB(i);
        }

        for(int i=1;i<n;++i){
            int u=read(),v=read();
            add_edge(u,v);
        }

        // Root the tree at vertex 1: fills dfn/R/sz/dep and the binary-lifting table.
        dep[1]=1;
        dfs(1,0);
        for(int color=1;color<=n;++color){
            // No vertex has this color, so no path can contain it.
            // '\n' instead of endl: avoids flushing stdout once per color.
            if(!SZ(nodes[color])){
                cout<<0<<'\n';
                continue;
            }

            ans=1ll*n*(n+1)/2;
            sort(all(nodes[color]),cmp);

            // Build the virtual tree of this color with a monotone stack over the vertices in
            // dfn order. st[1..top] is the current chain from the root; vertex 1 always stays
            // at the bottom so the virtual tree is rooted at the real root. headv is reset
            // when a vertex is first pushed, discarding stale lists from earlier colors.
            st[top=1]=1;headv[1]=0;tolv=0;
            for(int i=0;i<SZ(nodes[color]);++i){
                int nn=nodes[color][i];
                mark[nn]=1;
                if(nn==1)continue;   // already on the stack as the root

                int l=getLCA(st[top],nn);

                if(l!=st[top]){
                    // nn branches off the chain at l: pop every vertex strictly below l,
                    // linking each to its predecessor on the stack.
                    while(dfn[l]<dfn[st[top-1]]){
                        add_edgev(st[top-1],st[top]);
                        --top;
                    }
                    if(dfn[l]>dfn[st[top-1]]){
                        // l is a new branching vertex; it replaces st[top] on the stack.
                        headv[l]=0;add_edgev(l,st[top]);st[top]=l;
                    }else{
                        // Here l == st[top-1], already on the stack.
                        add_edgev(l,st[top--]);
                    }
                }
                headv[nn]=0;st[++top]=nn;
            }

            // Link whatever chain remains on the stack.
            for(int i=1;i<top;++i){
                add_edgev(st[i],st[i+1]);
            }

            dfs1(st[1]);
            cout<<ans<<'\n';
            // Clear marks so the next color starts clean.
            for(int i=0;i<SZ(nodes[color]);++i)mark[nodes[color][i]]=0;
        }

        return 0;
    }
    View Code

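    顺便附上 O(n) 的做法。只做一次 DFS,用 sum[c] 记录已遍历部分中"最上层颜色 c 节点"的子树大小之和。对颜色为 c 的节点 u 的每个儿子 v,递归 v 前后 sum[c] 的增量就是 v 子树中被颜色 c 覆盖的节点数,于是 v 一侧不含颜色 c 的连通块大小为 sz[v] 减去该增量。离开 u 时,把 sum[c] 置为进入 u 时的值加上 sz[u]。最后,n − sum[c] 就是包含根的连通块大小。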

    #include<bits/stdc++.h>
    
    #define all(x) x.begin(),x.end()
    #define fi first
    #define sd second
    #define lson (nd<<1)
    #define rson (nd+nd+1)
    #define PB push_back
    #define mid (l+r>>1)
    #define MP make_pair
    #define SZ(x) (int)x.size()
    
    using namespace std;
    
    typedef long long LL;
    
    typedef vector<int> VI;
    
    typedef pair<int,int> PII;
    
    // Reads the next decimal integer (64-bit) from stdin using getchar().
    // Non-digit characters before the number are skipped; a '-' seen while skipping
    // makes the result negative. If end of input is reached before any digit, returns 0
    // instead of spinning forever.
    inline long long read(){
        long long res=0, f=1;
        int ch=getchar();   // int, not char: must be able to hold EOF distinctly
        while(ch!=EOF&&(ch<'0'||ch>'9')){
            if(ch=='-')f=-1;
            ch=getchar();
        }
        while(ch>='0'&&ch<='9'){
            res=res*10+(ch-'0');
            ch=getchar();
        }
        return res*f;
    }
    
    const int MAXN = 200'005;
    
    const int MOD = 1000000007;
    
    // a <- (a + b) mod MOD, assuming both operands are already reduced into [0, MOD).
    void addmod(int& a, int b){
        const int total=a+b;
        a=(total>=MOD)?total-MOD:total;
    }
    // (a * b) mod MOD; the product is formed in 64-bit so it cannot overflow int.
    int mulmod(int a, int b){
        const long long product=static_cast<long long>(a)*b;
        return static_cast<int>(product%MOD);
    }
    
    // Lowers a to b when b is strictly smaller (a > b).
    template<class T>
    void chmin(T& a, T b){
        if(!(a>b)) return;
        a=b;
    }
    
    // Raises a to b when b is strictly larger (b > a).
    template<class T>
    void chmax(T& a, T b){
        if(!(b>a)) return;
        a=b;
    }
    
    LL n;
    
    LL sz[MAXN], sum[MAXN], ans[MAXN];
    LL col[MAXN];
    
    #define go(e,u) for(int e=head[u];e;e=Next[e])
    int to[MAXN<<1],Next[MAXN<<1],head[MAXN],tol;
    
    // Adds an undirected edge u-v to the forward-star lists (head/Next/to),
    // storing one directed arc in each direction.
    void add_edge(int u,int v){
        const int fwd=++tol;
        to[fwd]=v;
        Next[fwd]=head[u];
        head[u]=fwd;

        const int back=++tol;
        to[back]=u;
        Next[back]=head[v];
        head[v]=back;
    }
    
    // Number of paths (u <= v, single vertices included) inside a connected set of x vertices.
    LL calc(LL x){
        const LL doubled=x*(x+1);
        return doubled/2;
    }
    
    void dfs(int u,int f){
        int c=col[u];
        sz[u]=1;LL o=sum[c];
        go(e,u){
            int v=to[e];
            if(v==f)continue;
            LL t=sum[c];
            dfs(v,u);
            ans[c]-=calc(sz[v]-(sum[c]-t));
            sz[u]+=sz[v];
        }
        sum[col[u]]=o+sz[u];
    }
    
    // Reads the tree, runs one DFS, and prints for each color the number of paths
    // (u <= v, single vertices included) containing at least one vertex of that color.
    int main(){
        n=read();
        // Every color starts from the total number of paths n(n+1)/2.
        for(int i=1;i<=n;++i)col[i]=read(),ans[i]=n*(n+1)/2;

        for(int i=1,u,v;i<n;++i){
            u=read();
            v=read();
            add_edge(u,v);
        }

        dfs(1,0);

        for(int i=1;i<=n;++i){
            // Vertices with no color-i ancestor (themselves included) form the component of
            // the root; subtract the paths inside it.
            LL t=n-sum[i];
            ans[i]-=calc(t);
            // '\n' instead of endl: avoids flushing stdout once per color.
            cout<<ans[i]<<'\n';
        }

        return 0;
    }
    View Code
  • 相关阅读:
    149. Max Points on a Line(js)
    148. Sort List(js)
    147. Insertion Sort List(js)
    146. LRU Cache(js)
    145. Binary Tree Postorder Traversal(js)
    144. Binary Tree Preorder Traversal(js)
    143. Reorder List(js)
    142. Linked List Cycle II(js)
    141. Linked List Cycle(js)
    140. Word Break II(js)
  • 原文地址:https://www.cnblogs.com/JohnRan/p/12773872.html
Copyright © 2011-2022 走看看