zoukankan      html  css  js  c++  java
  • Count on a Tree II

    Count on a Tree II

    寒假 DLS 讲过一个做法,写一下。

    考虑随机在这个树上打 $ sqrt n $ 个点作为关键点,然后每个点向上跳找到第一个关键点的期望长度是 $ sqrt n $ 的。当然,也有真实的打点方法,用深度是 $ sqrt n $ 倍数并且向下深度大于等于根号的点来打,这样打出来点的个数是 $ sqrt n $ 的并且明显从每个点向上的长度也是 $ sqrt n $ 的。

    然后我们先预处理出关键点间两两的答案,做法是从每个关键点跑一次 DFS 。

    然后考虑一次询问,对于一次询问我们需要查询这两个点到关键点的路径上有没有什么颜色是不存在于关键点的路径上的。也就是说我们需要维护根到每个节点的某个颜色的数量。直接主席树维护得带 $ log $ 不优秀,可以考虑用根号平衡,修改 $ O(sqrt n) $ 查询 $ O(1) $ ,套个可持久化就好了。

    还需要用 ST 表 $ O(1) $ 维护 LCA。这样做复杂度就是在线的 $ O(msqrt n) $ 了(虽然常数有点大)。

    然后它成功在 BZOJ TLE了。。。

    懒得调了,至少在 DB 和 SPOJ 过了。。

    #include "iostream"
    #include "algorithm"
    #include "cstring"
    #include "cstdio"
    #include "cmath"
    using namespace std;
    #define MAXN 40006
    #define B 206
    #define swap( a , b ) ( (a) ^= (b) , (b) ^= (a) , (a) ^= (b) )
    int n , m , blo , sz , bl;
    int c[MAXN] , C[MAXN];
    int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn;
    void ade( int u , int v ) {
        to[++ ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
    }
    int ps[B] , im[MAXN] , pre[B][B] , cid , ee;
    int w[MAXN] , an;
    void dfs( int u , int fa ) {
        if( !w[c[u]] ) ++ an;
        ++ w[c[u]];
        if( im[u] ) pre[cid][im[u]] = an;
        for( int i = head[u] ; i ; i = nex[i] ) {
            int v = to[i];
            if( v == fa ) continue;
            dfs( v , u );
        }
        -- w[c[u]];
        if( !w[c[u]] ) -- an;
    }
    
    int kk[MAXN << 2][B] , idx = 0 , rt[MAXN];
    void build( ) {
        idx = rt[0] = 1;
        for( int i = 1 ; ( i - 1 ) * bl < sz ; ++ i ) kk[1][i] = i + 1 , ++ idx;
    }
    void add( int rt , int old , int x ) {
        memcpy( kk[rt] , kk[old] , sizeof kk[old] );
        ++ idx;
        memcpy( kk[idx] , kk[kk[rt][( x - 1 ) / bl + 1]] , sizeof kk[1] );
        kk[rt][( x - 1 ) / bl + 1] = idx;
        ++ kk[idx][( x - 1 ) % bl + 1];
    }
    int que( int rt , int x ) {
        int t = kk[rt][( x - 1 ) / bl + 1];
        return kk[t][( x - 1 ) % bl + 1];
    }
    
    int dfn[MAXN << 1] , dep[MAXN << 1] , pos[MAXN << 1] , en = 0 , d[MAXN] , par[MAXN];
    int dfs1( int u , int fa ) {
        dfn[++en] = u , pos[u] = en , dep[en] = d[u] , par[u] = fa;
        rt[u] = ++ idx;
        add( rt[u] , rt[fa] , c[u] );
        int re = 0;
        for( int i = head[u] ; i ; i = nex[i] ) {
            int v = to[i];
            if( v == fa ) continue;
            d[v] = d[u] + 1;
            re = max( re , dfs1( v , u ) );
            dfn[++ en] = u , dep[en] = d[u];
        }
        if( re - d[u] >= blo && d[u] % blo == 0 ) ps[++ ee] = u , im[u] = ee;
        return re ? re : d[u];
    }
    int lg[MAXN << 1];
    int st[MAXN << 1][17];
    void ST() {
        for( int i = 1 ; i <= en ; ++ i ) lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
        for( int i = 1 ; i <= en ; ++ i ) st[i][0] = i;
        for( int i = 1 ; ( 1 << i ) <= en ; ++ i )
            for (int j = 1; j + (1 << i) - 1 <= en; j++)
                st[j][i] = dep[st[j][i - 1]] < dep[st[j + (1 << (i - 1))][i - 1]] ? st[j][i - 1] : st[j + (1 << (i - 1))][i - 1];
    }
    int lca( int l , int r ) {
        l = pos[l] , r = pos[r];
        if( l > r ) swap( l , r );
        int k = lg[r - l + 1] - 1;
        return dep[st[l][k]] <= dep[st[r-(1<<k)+1][k]] ? dfn[st[l][k]] : dfn[st[r-(1<<k)+1][k]];
    }
    int col[MAXN] , cn , l;
    int jump( int u ) {
        if( im[u] ) return im[u];
        col[++ cn] = c[u];
        if( u == l ) return -1;
        return jump( par[u] );
    }
    int rejjump( int u ) {
        if( im[u] ) return im[u];
        col[++ cn] = c[u];
        return rejjump( par[u] );
    }
    int oc[MAXN];
    int main() {
    //    freopen("10.in","r",stdin);
    //    freopen("ot","w",stdout);
        cin >> n >> m;
        blo = sqrt( n );
        for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&c[i]) , C[i] = c[i];
        sort( C + 1 , C + 1 + n ); sz = unique( C + 1 , C + 1 + n ) - C - 1;
        bl = sqrt( sz );
        for( int i = 1 ; i <= n ; ++ i ) c[i] = lower_bound( C + 1 , C + 1 + sz , c[i] ) - C;
        for( int i = 1 , u , v ; i < n ; ++ i ) {
            scanf("%d%d",&u,&v);
            ade( u , v ) , ade( v , u );
        }
        build( );
        dfs1( 1 , 0 );
    //    cout << que( rt[3] , 3 ) << endl;
        ST( );
        for( int i = 1 ; i <= blo ; ++ i ) {
            int p = ps[i]; cid = i;
            dfs( p , p );
        }
        int u , v , psu , psv , flg = 0 , re , last = 0 , t , pr;
        while( m-- ) {
            scanf("%d%d",&u,&v);
            u ^= last;
    //        cout << u << ' ' << v << endl;
            l = lca( u , v );
            cn = re = 0;
            psu = jump( u ) , psv = jump( v );
            if( psu == psv ) {
                oc[c[l]] = 1 , ++ re;
                for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) oc[col[i]] = 1 , ++ re;
                for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
                oc[c[l]] = 0;
            } else if( ~psu && ~psv ) {
                re = pre[psu][psv];
                for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
                        oc[col[i]] = 1;
                        if( que( rt[ps[psu]] , col[i] ) + que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) * 2 == 0 )
                            ++ re;
                    }
                for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
            } else { 
                if( ~psu ) swap( u , v ) , swap( psu , psv );
                pr = l; t = cn;
                while( !im[pr] ) {
                    pr = par[pr];
                    if( !oc[c[pr]] ) {
                        oc[c[pr]] = 1;
                        if (que(rt[ps[psv]], c[pr]) - que(rt[par[l]], c[pr]) == 0)
                            --re;
                        col[++ cn] = c[pr];
                    }
                }
                re += pre[im[pr]][psv];
                for( int i = t + 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
                cn = t;
                for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
                        oc[col[i]] = 1;
                        if( que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) == 0 )
                            ++ re;
                    }
                for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
            }
            printf("%d
    ",last = re);
        }
    }
    
    
  • 相关阅读:
    6-ESP8266 SDK开发基础入门篇--操作系统入门使用
    5-ESP8266 SDK开发基础入门篇--了解一下操作系统
    【java规则引擎】基本语法和相关属性介绍
    【eclipse】 怎么解决java.lang.NoClassDefFoundError错误
    【java规则引擎】java规则引擎搭建开发环境
    【4】JDK和CGLIB生成动态代理类的区别
    【java规则引擎】一个基于drools规则引擎实现的数学计算例子
    【3】SpringMVC的Controller
    设计模式之禅之代理模式
    【java规则引擎】规则引擎RuleBase中利用观察者模式
  • 原文地址:https://www.cnblogs.com/yijan/p/12378129.html
Copyright © 2011-2022 走看看