zoukankan      html  css  js  c++  java
  • 树分治learning

    学习了树的点分治,树的边分治似乎因为复杂度过高而并不出众,于是没学

    自己总结了一下 有些时候面对一些树上的结构 并且解决的是和路径有关的问题的时候 如果是多个询问 关注点在每次给出两个点,求一些关于这两个点之间路径的问题的时候,我们可以使用树链剖分,但是如果是给出一个单一的询问,但是很宏观 类似于求所有点对之间路径满足xx的数量,这时候我们可以树形dp做些什么 但是有时候会遇到一些树形dp难以解决的东西,类似于数组开不下,无法转移状态这种问题,就可以用树分治

    树分治基于一个思想 先确定一个点 找到所有过这个点的路径并判断 再对被这个点分开的连通块做同样的操作

    QAQ于是做了几道模板题

    POJ1741 模板题 求一棵树中 满足点对之间路径加和小于k的数量

    这里用到了一个动态规划 判断一个数组中 选两个数能加起来<k  做法是sort后维护两个指针

    其余的地方都很模板,需要注意的是要对root的所有son进行solve之后再重置lr

    #include<stdio.h>
    #include<math.h>
    #include<string.h>
    #include<vector>
    #include<queue>
    #include<map>
    #include<string>
    #include<iostream>
    #include<algorithm>
    #include<stack>
    using namespace std;
    #define L long long
    #define pb push_back
    #define lala printf("--------
    ");
    #define ph push
    #define rep(i, a, b) for (int i=a;i<=b;++i)
    #define dow(i, b, a) for (int i=b;i>=a;--i)
    #define fmt(i,n) if(i==n)printf("
    ");else printf(" ") ;
    #define fi first
    #define se second
    template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
    int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
    int n , m ;
    bool vis[10050] ;
    int dis[10050] ;
    ///---
    struct node {
        int v,w,nex;
    }b[10050*2];
    int tot ;
    int head[10050] ;
    void init() {
        flc(head,-1);
        tot = 0 ;
    }
    void add(int u,int v,int w) {
        tot ++ ;
        b[tot].v=v;
        b[tot].w=w;
        b[tot].nex=head[u];
        head[u]=tot;
    }
    ///---
    int son[10050] ;
    int getsize(int u,int fa) {
        son[u] = 1 ;
        for(int i = head[u] ; i != -1 ; i = b[i].nex) {
            int v = b[i].v ;
            if(v==fa || vis[v]) continue ;
            son[u]+=getsize(v,u);
        }
        return son[u];
    }
    int minn ;
    void getroot(int u,int fa,int &root,int siz) {
        int maxx = siz - son[u] ;
        for(int i = head[u] ; i != -1 ; i = b[i].nex) {
            int v=b[i].v ;
            if(v==fa || vis[v]) continue ;
            getroot(v,u,root,siz) ;
            maxx = max(maxx,son[v]) ;
        }
        if(minn == -1 || maxx < minn) {
             minn = maxx ;
             root = u ;
        }
    }
    ///---
    int l , r ;
    void getdepth(int u,int fa,int xd) {
        dis[++r] = xd ;
        for(int i = head[u] ; i != -1 ; i = b[i].nex) {
            int v = b[i].v ;
            int w = b[i].w ;
            if(v == fa || vis[v]) continue ;
            getdepth(v , u , xd + w) ;
        }
    }
    bool cmp(int a, int b) {
        return a<b;
    }
    int getdep(int l , int r) {
        if(l >= r) return 0 ;
        sort(dis + l , dis + r + 1 , cmp ) ;
        int res = 0 ;
        int le = l ;
        int ri = l-1 ;
        while(ri+1 <= r && dis[ri+1] + dis[le] <= m) {
            ri ++ ;
            res ++ ;
        }
        while(le + 1 <= r) {
            le ++ ;
            while(ri >= l && dis[ri] + dis[le] > m) ri -- ;
            res += ri - l + 1 ;
        }
    
        for(int i = l ; i <= r ; i ++ ) {
            if(dis[i]*2 <= m) res -- ;
        }
        return (res / 2) ;
    }
    ///---
    int solve(int u) {
        int siz = getsize(u , -1) ;
        minn = -1 ;
        int root = -1 ;
        getroot(u , -1 , root , siz) ;
        vis[root] = true ;
        int res = 0 ;
        for(int i = head[root] ; i != -1 ; i = b[i].nex) {
            int v = b[i].v ;
            if(vis[v]) continue ;
            int z = solve(v) ;
            res += z ;
        }
        l = 1 ;
        r = 0 ;
        for(int i = head[root] ; i != -1 ; i = b[i].nex) {
            int v = b[i].v ;
            int w = b[i].w ;
            if(vis[v]) continue ;
            getdepth(v , root , w) ;
            res -= getdep(l , r) ;
            l = r + 1 ;
        }
        res += getdep(1 , r) ;
        for(int i = 1 ; i <= r ; i ++ ) {
            if(dis[i] <= m) res ++ ;
            else break ;
        }
        vis[root] = false ;
        return res ;
    }
    
    int main () {
        while(scanf("%d%d" , &n, &m) != EOF) {
            if(n==0&&m==0) break ;
            init() ;
            rep(i,1,n-1) {
                int u,v,w;
                u=read();v=read();w=read();
                add(u,v,w);
                add(v,u,w);
            }
            memset(vis,false,sizeof(vis));
            int ans = solve(1) ;
            printf("%d
    " , ans) ;
        }
    }
    

    BZOJ 2152 求路径%3==0的点对的数量

    因为只是%3 所以比较容易些。。如果写树形dp的话会比较好写 维护一个dp[n][3]的数组就可以

    但是如果不是3是很大的数字 就得开dp[n][m] 如果开不下的话就得树分治

    /// 树形dp跑得又好又快QAQ

    #include<stdio.h>
    #include<math.h>
    #include<string.h>
    #include<vector>
    #include<queue>
    #include<map>
    #include<string>
    #include<iostream>
    #include<algorithm>
    #include<stack>
    using namespace std;
    #define L long long
    #define pb push_back
    #define lala printf("--------
    ");
    #define ph push
    #define rep(i, a, b) for (int i=a;i<=b;++i)
    #define dow(i, b, a) for (int i=b;i>=a;--i)
    #define fmt(i,n) if(i==n)printf("
    ");else printf(" ") ;
    #define fi first
    #define se second
    template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
    int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
    int n ;
    struct node {
        int v,w,nex;
    }b[20050 * 2];
    int head[20050];
    int tot ;
    void add(int u,int v,int w) {
        tot++;
        b[tot].v=v;b[tot].w=w;
        b[tot].nex=head[u];head[u]=tot;
    }
    void init() {
        flc(head,-1);
        tot=0;
    }
     
    int dp[20050][5] ;
    int ans ;
     
    void dfs(int u,int fa) {
        int a[4];
        flc(a,0) ;
        a[0] = 1 ;
        for(int i=head[u];i!=-1;i=b[i].nex) {
            int v=b[i].v ;
            int w=b[i].w ;
            if(v==fa) continue ;
            dfs(v,u) ;
            rep(j,0,2) {
                dp[u][(j+w)%3]+=dp[v][j] ;
            }
            rep(j,0,2) {
                int z=w+j;
                z%=3 ;
                if(z==0) {
                    ans += a[0]*dp[v][j] ;
                }
                if(z==1) {
                    ans += a[2]*dp[v][j] ;
                }
                if(z==2) {
                    ans += a[1]*dp[v][j] ;
                }
            }
            rep(j,0,2) {
                a[(j+w)%3]+=dp[v][j] ;
            }
        }
        dp[u][0] ++ ;
    }
     
    int main (){
        while(scanf("%d" , &n) != EOF) {
            init() ;
            flc(dp,0);
            ans = 0 ;
            rep(i,1,n-1){
                int u=read(),v=read(),w=read();
                add(u,v,w);add(v,u,w);
            }
            dfs(1,-1);
            int fm = n*n;
            ans *= 2 ;
            ans += n ;
            int gc = __gcd(fm,ans) ;
            fm/=gc ;
            ans/=gc ;
            printf("%d/%d",ans,fm) ;
        }
    }
    
    #include<stdio.h>
    #include<math.h>
    #include<string.h>
    #include<vector>
    #include<queue>
    #include<map>
    #include<string>
    #include<iostream>
    #include<algorithm>
    #include<stack>
    using namespace std;
    #define L long long
    #define pb push_back
    #define lala printf("--------
    ");
    #define ph push
    #define rep(i, a, b) for (int i=a;i<=b;++i)
    #define dow(i, b, a) for (int i=b;i>=a;--i)
    #define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex)
    #define fmt(i,n) if(i==n)printf("
    ");else printf(" ") ;
    #define fi first
    #define se second
    template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
    int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
     
    int n , m ;
    int ans ;
    bool vis[20050] ;
    int dis[20050] ;
    ///---
    struct node {
        int v,w,nex;
    }b[20050*2];
    int tot ;
    int head[20050] ;
    void init() {
        flc(head,-1);
        tot = 0 ;
    }
    void add(int u,int v,int w) {
        tot ++ ;
        b[tot].v=v;
        b[tot].w=w;
        b[tot].nex=head[u];
        head[u]=tot;
    }
    ///---
    int siz[20050];
    int getsize(int u,int fa) {
        siz[u] = 1 ;
        for(int i=head[u];i!=-1;i=b[i].nex) {
            int v=b[i].v;
            if(v==fa||vis[v]) continue ;
            siz[u]+=getsize(v,u);
        }
        return siz[u];
    }
    int minn ;
    void getroot(int u,int fa,int num,int &root) {
        int maxx=0;
        for(int i=head[u];i!=-1;i=b[i].nex){
            int v=b[i].v ;
            if(v==fa||vis[v]) continue ;
            getroot(v,u,num,root);
            maxx=max(maxx,siz[v]);
        }
        maxx=max(maxx,num-siz[u]);
        if(maxx<minn){
            minn=maxx;root=u;
        }
    }
    ///---
    int l,r;
    void getdepth(int u,int fa,int xd) {
        dis[++r]=xd ;
        rnode(i,u) {
            int v=b[i].v ;
            if(v==fa||vis[v]) continue ;
            int w=b[i].w ;
            getdepth(v,u,xd+w) ;
        }
    }
    int getdep(int l,int r) {
        if(l>r) return 0 ;
        int a[3] ; flc(a,0) ;
        rep(i,l,r) {
            a[dis[i]%3] ++ ;
        }
        int res = 0 ;
        rep(i,0,2) {
            rep(j,0,2) {
                if((i+j)%3==0) res += a[i]*a[j] ;
            }
        }
        return res ;
    }
    ///---
    int solve(int u) {
        int num = getsize(u,-1);
        minn = 999999999 ;
        int root ;
        getroot(u,-1,num,root);
        int ans = 0 ;
        vis[root]=true;
        rnode(i,root) {
            int v=b[i].v;
            if(vis[v]) continue ;
            ans += solve(v) ;
        }
        l = 1 ;
        r = 0 ;
        rnode(i,root) {
            int v=b[i].v;
            int w=b[i].w;
            if(vis[v]) continue ;
            getdepth(v,root,w) ;
            ans -= getdep(l,r) ;
            l = r + 1 ;
        }
        dis[++r] = 0 ;
        ans += getdep(1,r) ;
        vis[root] = false ;
        return ans ;
    }
     
     
     
     
    int main () {
        while(scanf("%d" , &n) != EOF) {
            init() ;
            rep(i,1,n-1) {
                int u=read();int v=read() ; int w = read();
                add(u,v,w) ; add(v,u,w) ;
            }
            memset(vis,false,sizeof(vis)) ;
            int ans = solve(1) ;
            int fm = n*n ;
            int g = __gcd(ans,fm) ;
            fm/=g ;
            ans/=g ;
            printf("%d/%d
    " , ans , fm) ;
        }
    }
    

    HDU 5977 大连的铜牌题 求包含所有颜色的路径的数目 k<=10

    这个题的颜色来源于点 在对root的son进行getdepth的时候 需要把root的颜色给带下去 因为我们用ans-root的同一个son内的孩子 里面肯定是包含root的颜色的 这个关系是或 所以可以直接或上去

    #include<stdio.h>
    #include<math.h>
    #include<string.h>
    #include<vector>
    #include<queue>
    #include<map>
    #include<string>
    #include<iostream>
    #include<algorithm>
    #include<stack>
    using namespace std;
    #define L long long
    #define pb push_back
    #define lala printf("--------
    ");
    #define ph push
    #define rep(i, a, b) for (L i=a;i<=b;++i)
    #define dow(i, b, a) for (L i=b;i>=a;--i)
    #define rnode(i,u) for(L i = head[u] ; i != -1 ; i = b[i].nex)
    #define fmt(i,n) if(i==n)printf("
    ");else printf(" ") ;
    #define fi first
    #define se second
    template<class T> inline void flc(T &A, L x){memset(A, x, sizeof(A));}
    L read(){L x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
    
    L n , k , m ;
    L ans ;
    bool vis[50050] ;
    L bl[50050] ;
    ///---
    struct node {
        L v,nex;
    }b[50050*2];
    L tot ;
    L head[50050] ;
    L dis[50050] ;
    void init() {
        flc(head,-1);
        tot = 0 ;
        memset(vis,false,sizeof(vis)) ;
    }
    void add(L u,L v) {
        tot ++ ;
        b[tot].v=v;
        b[tot].nex=head[u];
        head[u]=tot;
    }
    ///---
    vector<int>q[2050] ;
    void thefirst() {
        L z = (1<<k)-1 ;
        rep(i,0,1024) q[i].clear() ;
        rep(i,0,z) {
            rep(j,0,z) {
                if( (i|j) == z) {
                    q[i].pb(j) ;
                }
            }
        }
    }
    ///---
    L siz[50050];
    L getsize(L u,L fa) {
        siz[u] = 1 ;
        rnode(i,u){
            L v=b[i].v;
            if(v==fa||vis[v]) continue ;
            siz[u]+=getsize(v,u);
        }
        return siz[u];
    }
    L minn ;
    void getroot(L u,L fa,L num,L &root) {
        L maxx=0;
        rnode(i,u){
            L v=b[i].v ;
            if(v==fa||vis[v]) continue ;
            getroot(v,u,num,root);
            maxx=max(maxx,siz[v]);
        }
        maxx=max(maxx,num-siz[u]);
        if(minn==-1||maxx<minn){
            minn=maxx;root=u;
        }
    }
    ///---
    L l , r ;
    void getdepth(L u,L fa,L xd) {
        xd |= (1 << (bl[u]-1)) ;
        dis[++r] = xd ;
        rnode(i,u) {
            L v=b[i].v;
            if(vis[v] || v==fa) continue ;
            getdepth(v,u,xd) ;
        }
    }
    L mp[2050] ;
    L getdep(L l , L r) {
        if(l>r) return 0 ;
        flc(mp,0) ;
        L ans = 0 ;
        rep(i,l,r) {
            L x=dis[i] ;
            for(L i=0;i<q[x].size();i++){
                L y=q[x][i];
                ans += mp[y] ;
            }
            mp[x] ++ ;
        }
        return ans ;
    }
    ///---
    L solve(L u) {
        L siz = getsize(u,-1) ;
        L root = -1;
        minn = -1 ;
        getroot(u,-1,siz,root) ;
        vis[root]=true ;
        L ans = 0 ;
        rnode(i,root) {
            L v=b[i].v;
            if(vis[v]) continue ;
            ans += solve(v) ;
        }
        l = 1 ;
        r = 0 ;
        rnode(i,root) {
            L v=b[i].v ;
            if(vis[v]) continue ;
            getdepth(v,root,(1<<(bl[root]-1))) ;
            ans -= getdep(l,r) ;
            l = r + 1 ;
        }
        L x = (1<<(bl[root]-1)) ;
        L K = (1<<k)-1 ;
        rep(i,1,r) {
            if((x | dis[i]) == K) {
                ans ++ ;
            }
        }
        ans += getdep(1,r) ;
        vis[root]=false;
        return ans ;
    }
    
    int main () {
        while(scanf("%lld%lld" , &n,&k) != EOF) {
            init() ;
            thefirst() ;
            rep(i,1,n) bl[i] = read() ;
            rep(i,1,n-1) {
                L u=read(),v=read();
                add(u,v) ;
                add(v,u) ;
            }
            if(k == 1) {
                printf("%lld
    " , n*n) ;
                continue ;
            }
            L ans = solve(1) ;
            printf("%lld
    " , ans*2) ;
        }
    }
    

    学会了树分治之后开启了新技能“看见什么不明显DP的树上结构就觉得可以树分治” 感觉要分治算法学傻。。

    训练赛看到一个题 感觉树形DP不可做 于是想树分治 发现解决不了这个问题 但是感觉还是树分治 赛后发现果然

    uvaLive 6900 给出一棵树 每条边有cost与val 我有C 在树上选一条路径出来 使sum(cost) <= C时的最大val

    这个和加减不太一样 因为加减是可以通过对root的son来操作进行去重的 上一个大连的是进行或运算 也无可厚非 但是这个求max 是不可逆的

    但是我们本来就不需要去重 和以前模板思路不一样的是 我们保存dis数组中 每一个值来自哪个root的儿子R 然后对R排序 处理完一个R再搞另一个R 我们不需要排序 因为根据dfs的特性 相同的R一定有且只有一段 所以不需要sort 和之前的去重没有什么时间上的差别 因为省去了去重的时间 所以我想 时间应该会更快

    在第一道题里面 用一个sort+O(n)单调思想 其实sort就撑到nlogn了 所以之后的nlogn也是可以接受的 可以做一个树状数组 来维护前缀max

    因为不能开太大 所以进行一个离散化 时间也是nlogn的 最后的复杂度还是nlognlogn 虽然常数大点

    这种思想是泛用的 之前的几道题也可以这么做

    uvaLive 6900

    #include<stdio.h>
    #include<math.h>
    #include<string.h>
    #include<vector>
    #include<queue>
    #include<map>
    #include<string>
    #include<iostream>
    #include<algorithm>
    #include<stack>
    using namespace std;
    #define L long long
    #define pb push_back
    #define lala printf("--------
    ");
    #define ph push
    #define rep(i, a, b) for (int i=a;i<=b;++i)
    #define dow(i, b, a) for (int i=b;i>=a;--i)
    #define fmt(i,n) if(i==n)printf("
    ");else printf(" ") ;
    #define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex)
    #define fi first
    #define se second
    template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
    int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
    int n , m;
    struct node {
        int vol,val;
        int R ;
    }dis[20050];
    bool vis[20050];
    struct no {
        int v,vol,val,nex;
    }b[20050*2];
    int head[20050];
    int tot;
    void init() {
        flc(head,-1);
        tot=0;
    }
    void add(int u,int v,int vol,int val) {
        tot++ ;
        b[tot].v=v;b[tot].vol=vol;b[tot].val=val;
        b[tot].nex=head[u] ; head[u]=tot;
    }
    int V ;
    ///---
    int son[20050] ;
    int getsize(int u,int fa) {
        son[u] = 1 ;
        rnode(i,u) {
            int v=b[i].v;
            if(v==fa || vis[v]) continue ;
            son[u] += getsize(v,u) ;
        }
        return son[u] ;
    }
    int minn ;
    void getroot(int u,int fa,int &root,int siz) {
        int maxx = siz - son[u] ;
        rnode(i,u) {
            int v=b[i].v ;
            if(v==fa || vis[v]) continue ;
            if(son[v] > maxx) maxx = son[v] ;
            getroot(v,u,root,siz) ;
        }
        if(maxx < minn) {
            minn = maxx;
            root = u ;
        }
    }
    ///---
    int l,r ;
    void getdepth(int u,int fa,int xdvol,int xdval,int sp) {
        node tmp ;
        tmp.vol = xdvol ;
        tmp.val = xdval ;
        tmp.R = sp ;
        dis[++r] = tmp ;
        rnode(i,u) {
            int v = b[i].v ;
            if(v == fa || vis[v]) continue ;
            getdepth(v,u,xdvol + b[i].vol,xdval + b[i].val,sp) ;
        }
    }
    int c[40050] ;
    int lowbit(int x) {
        return (x&(-x)) ;
    }
    void segadd(int x,int ma) {
        while(x<=40000) {
            c[x]=max(c[x] , ma) ;
            x+=lowbit(x) ;
        }
    }
    int fin(int x) {
        int res = 0 ;
        while(x>0) {
            res = max(c[x],res);
            x-=lowbit(x) ;
        }
        return res ;
    }
    int calc(int l,int r) {
        if(l > r) return 0 ;
        flc(c,0) ;
        vector<int>ls ; ls.clear() ;
        rep(i,l,r) {
            ls.pb(dis[i].vol) ;
        }
        int res = 0 ;
        sort(ls.begin(),ls.end()) ;
        ls.erase(unique(ls.begin(),ls.end()) , ls.end()) ;
        for(int i = l ; i <= r ; i ++ ) {
            int j = i ;
            while(j <= r && dis[j].R == dis[i].R) {
                int z = dis[j].vol ;
                int val1 = dis[j].val ;
                if(z > V) {
                    j ++ ;
                    continue ;
                }
                int x = V - z ;
                int id = -2 ;
                int ll = 0 ;
                int rr = ls.size()-1 ;
                while(ll<=rr) {
                    int mid=(ll+rr)/2 ;
                    if(ls[mid]<=x) {
                        id=mid;
                        ll=mid+1;
                    }
                    else {
                        rr=mid-1;
                    }
                }
                if(id==-2){
                    j++;
                    continue ;
                }
                int rres = fin(id+1) ;
                res = max(res , rres + dis[j].val) ;
                j ++ ;
            }
            j -- ;
            rep(k,i,j) {
                int vol = dis[k].vol ;
                int val = dis[k].val ;
                int id = lower_bound(ls.begin(),ls.end(),vol)-ls.begin()+1 ;
                segadd(id,val) ;
            }
            i = j ;
        }
        return res ;
    }
    ///---
    int solve(int u) {
        int siz = getsize(u,-1) ;
        minn = 999999999 ;
        int root ;
        getroot(u,-1,root,siz) ;
        vis[root] = true ;
        int res = 0 ;
        rnode(i,root) {
            int v=b[i].v ;
            if(vis[v]) continue ;
            int x = solve(v) ;
            res = max(res,x) ;
        }
        l = 1 ;
        r = 0 ;
        rnode(i,root) {
            int v = b[i].v ;
            int vol = b[i].vol ; int val = b[i].val ;
            if(vis[v]) continue ;
            getdepth(v,root,vol,val,v) ;
        }
        res = max(res , calc(1,r)) ;
        rep(i,1,r) {
            if(dis[i].vol <= V) {
                res = max(dis[i].val , res) ;
            }
        }
        vis[root] = false ;
        return res ;
    }
    
    
    int main () {
        int t = read();
        while(t -- ) {
            n = read();
            init() ;
            rep(i,2,n) {
                int u=read(),v=read(),vol=read(),val=read() ;
                add(u,v,vol,val);
                add(v,u,vol,val);
            }
            V = read() ;
            memset(vis,false,sizeof(vis));
            int ans=solve(1) ;
            printf("%d
    " , ans) ;
        }
    }
    

    BZOJ 2152 用这种方法改了一下 发现由于必须sort 所以复杂度比之前的做法要多一个log

    #include<stdio.h>
    #include<math.h>
    #include<string.h>
    #include<vector>
    #include<queue>
    #include<map>
    #include<string>
    #include<iostream>
    #include<algorithm>
    #include<stack>
    using namespace std;
    #define L long long
    #define pb push_back
    #define lala printf("--------
    ");
    #define ph push
    #define rep(i, a, b) for (int i=a;i<=b;++i)
    #define dow(i, b, a) for (int i=b;i>=a;--i)
    #define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex)
    #define fmt(i,n) if(i==n)printf("
    ");else printf(" ") ;
    #define fi first
    #define se second
    template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));}
    int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
     
    int n , m ;
    int ans ;
    bool vis[20050] ;
    struct no {
        int x , R;
    }dis[20050] ;
    ///---
    struct node {
        int v,w,nex;
    }b[20050*2];
    int tot ;
    int head[20050] ;
    void init() {
        flc(head,-1);
        tot = 0 ;
    }
    void add(int u,int v,int w) {
        tot ++ ;
        b[tot].v=v;
        b[tot].w=w;
        b[tot].nex=head[u];
        head[u]=tot;
    }
    ///---
    int siz[20050];
    int getsize(int u,int fa) {
        siz[u] = 1 ;
        for(int i=head[u];i!=-1;i=b[i].nex) {
            int v=b[i].v;
            if(v==fa||vis[v]) continue ;
            siz[u]+=getsize(v,u);
        }
        return siz[u];
    }
    int minn ;
    void getroot(int u,int fa,int num,int &root) {
        int maxx=0;
        for(int i=head[u];i!=-1;i=b[i].nex){
            int v=b[i].v ;
            if(v==fa||vis[v]) continue ;
            getroot(v,u,num,root);
            maxx=max(maxx,siz[v]);
        }
        maxx=max(maxx,num-siz[u]);
        if(maxx<minn){
            minn=maxx;root=u;
        }
    }
    ///---
    int l,r;
    void getdepth(int u,int fa,int xd,int sp) {
        no tmp ;
        tmp.x = xd ;
        tmp.R = sp ;
        dis[++r] = tmp ;
        rnode(i,u) {
            int v=b[i].v ;
            if(v==fa||vis[v]) continue ;
            int w=b[i].w ;
            getdepth(v,u,xd+w,sp) ;
        }
    }
    int getdep(int l,int r) {
        if(l>r) return 0;
        int res = 0 ;
        int a[5] ; flc(a,0) ;
        for(int i = l ; i <= r ; i ++ ) {
            int j = i ;
            while(j <= r && dis[j].R==dis[i].R) {
                int x = dis[j].x % 3 ;
                int ned = 3 - x ;
                ned %= 3 ;
                res += a[ned] ;
                j ++ ;
            }
            j -- ;
            rep(k,i,j) {
                int x = dis[k].x % 3 ;
                a[x] ++ ;
            }
            i = j ;
        }
        return res ;
    }
    ///---
    int solve(int u) {
        int num = getsize(u,-1);
        minn = 999999999 ;
        int root ;
        getroot(u,-1,num,root);
        int ans = 0 ;
        vis[root]=true;
        rnode(i,root) {
            int v=b[i].v;
            if(vis[v]) continue ;
            ans += solve(v) ;
        }
        l = 1 ;
        r = 0 ;
        rnode(i,root) {
            int v=b[i].v;
            int w=b[i].w;
            if(vis[v]) continue ;
            getdepth(v,root,w,v) ;
        }
        ans += getdep(1,r) ;
        rep(i,1,r) {
            if(dis[i].x % 3 == 0) ans ++ ;
        }
        vis[root] = false ;
        return ans ;
    }
     
     
     
     
    int main () {
        while(scanf("%d" , &n) != EOF) {
            init() ;
            rep(i,1,n-1) {
                int u=read();int v=read() ; int w = read();
                add(u,v,w) ; add(v,u,w) ;
            }
            memset(vis,false,sizeof(vis)) ;
            int ans = solve(1) ;
            ans *= 2;
            ans += n ;
            int fm = n*n ;
            int g = __gcd(ans,fm) ;
            fm/=g ;
            ans/=g ;
            printf("%d/%d
    " , ans , fm) ;
        }
    }
    
  • 相关阅读:
    NDOC中文支持及入门用法
    网页代码常用小技巧
    SOCKET通讯点滴
    自动备份程序目录
    MySql.Data.dll Microsoft.Web.UI.WebControls.dll下载
    c#:获取IE地址栏中的URL
    比较好的单例登录模式(参考网友)
    FreeTextBox使用详解
    2005自定义控件显示基准线
    连接字符串大全
  • 原文地址:https://www.cnblogs.com/rayrayrainrain/p/7598082.html
Copyright © 2011-2022 走看看