Count on a Tree II
寒假 DLS 讲过一个做法,写一下。
考虑随机在这个树上打 $ sqrt n $ 个点作为关键点,然后每个点向上跳找到第一个关键点的期望长度是 $ sqrt n $ 的。当然,也有真实的打点方法,用深度是 $ sqrt n $ 倍数并且向下深度大于等于根号的点来打,这样打出来点的个数是 $ sqrt n $ 的并且明显从每个点向上的长度也是 $ sqrt n $ 的。
然后我们先预处理出关键点间两两的答案,做法是从每个关键点跑一次 DFS 。
然后考虑一次询问,对于一次询问我们需要查询这两个点到关键点的路径上有没有什么颜色是不存在于关键点的路径上的。也就是说我们需要维护根到每个节点的某个颜色的数量。直接主席树维护得带 $ log $ 不优秀,可以考虑用根号平衡,修改 $ O(sqrt n) $ 查询 $ O(1) $ ,套个可持久化就好了。
还需要用 ST 表 $ O(1) $ 维护 LCA。这样做复杂度就是在线的 $ O(msqrt n) $ 了(虽然常数有点大)。
然后它成功在 BZOJ TLE了。。。
懒得调了,至少在 DB 和 SPOJ 过了。。
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
using namespace std;
#define MAXN 40006
#define B 206
#define swap( a , b ) ( (a) ^= (b) , (b) ^= (a) , (a) ^= (b) )
int n , m , blo , sz , bl;
int c[MAXN] , C[MAXN];
int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn;
void ade( int u , int v ) {
to[++ ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
}
int ps[B] , im[MAXN] , pre[B][B] , cid , ee;
int w[MAXN] , an;
void dfs( int u , int fa ) {
if( !w[c[u]] ) ++ an;
++ w[c[u]];
if( im[u] ) pre[cid][im[u]] = an;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
dfs( v , u );
}
-- w[c[u]];
if( !w[c[u]] ) -- an;
}
int kk[MAXN << 2][B] , idx = 0 , rt[MAXN];
void build( ) {
idx = rt[0] = 1;
for( int i = 1 ; ( i - 1 ) * bl < sz ; ++ i ) kk[1][i] = i + 1 , ++ idx;
}
void add( int rt , int old , int x ) {
memcpy( kk[rt] , kk[old] , sizeof kk[old] );
++ idx;
memcpy( kk[idx] , kk[kk[rt][( x - 1 ) / bl + 1]] , sizeof kk[1] );
kk[rt][( x - 1 ) / bl + 1] = idx;
++ kk[idx][( x - 1 ) % bl + 1];
}
int que( int rt , int x ) {
int t = kk[rt][( x - 1 ) / bl + 1];
return kk[t][( x - 1 ) % bl + 1];
}
int dfn[MAXN << 1] , dep[MAXN << 1] , pos[MAXN << 1] , en = 0 , d[MAXN] , par[MAXN];
int dfs1( int u , int fa ) {
dfn[++en] = u , pos[u] = en , dep[en] = d[u] , par[u] = fa;
rt[u] = ++ idx;
add( rt[u] , rt[fa] , c[u] );
int re = 0;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
d[v] = d[u] + 1;
re = max( re , dfs1( v , u ) );
dfn[++ en] = u , dep[en] = d[u];
}
if( re - d[u] >= blo && d[u] % blo == 0 ) ps[++ ee] = u , im[u] = ee;
return re ? re : d[u];
}
int lg[MAXN << 1];
int st[MAXN << 1][17];
void ST() {
for( int i = 1 ; i <= en ; ++ i ) lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
for( int i = 1 ; i <= en ; ++ i ) st[i][0] = i;
for( int i = 1 ; ( 1 << i ) <= en ; ++ i )
for (int j = 1; j + (1 << i) - 1 <= en; j++)
st[j][i] = dep[st[j][i - 1]] < dep[st[j + (1 << (i - 1))][i - 1]] ? st[j][i - 1] : st[j + (1 << (i - 1))][i - 1];
}
int lca( int l , int r ) {
l = pos[l] , r = pos[r];
if( l > r ) swap( l , r );
int k = lg[r - l + 1] - 1;
return dep[st[l][k]] <= dep[st[r-(1<<k)+1][k]] ? dfn[st[l][k]] : dfn[st[r-(1<<k)+1][k]];
}
int col[MAXN] , cn , l;
int jump( int u ) {
if( im[u] ) return im[u];
col[++ cn] = c[u];
if( u == l ) return -1;
return jump( par[u] );
}
int rejjump( int u ) {
if( im[u] ) return im[u];
col[++ cn] = c[u];
return rejjump( par[u] );
}
int oc[MAXN];
int main() {
// freopen("10.in","r",stdin);
// freopen("ot","w",stdout);
cin >> n >> m;
blo = sqrt( n );
for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&c[i]) , C[i] = c[i];
sort( C + 1 , C + 1 + n ); sz = unique( C + 1 , C + 1 + n ) - C - 1;
bl = sqrt( sz );
for( int i = 1 ; i <= n ; ++ i ) c[i] = lower_bound( C + 1 , C + 1 + sz , c[i] ) - C;
for( int i = 1 , u , v ; i < n ; ++ i ) {
scanf("%d%d",&u,&v);
ade( u , v ) , ade( v , u );
}
build( );
dfs1( 1 , 0 );
// cout << que( rt[3] , 3 ) << endl;
ST( );
for( int i = 1 ; i <= blo ; ++ i ) {
int p = ps[i]; cid = i;
dfs( p , p );
}
int u , v , psu , psv , flg = 0 , re , last = 0 , t , pr;
while( m-- ) {
scanf("%d%d",&u,&v);
u ^= last;
// cout << u << ' ' << v << endl;
l = lca( u , v );
cn = re = 0;
psu = jump( u ) , psv = jump( v );
if( psu == psv ) {
oc[c[l]] = 1 , ++ re;
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) oc[col[i]] = 1 , ++ re;
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
oc[c[l]] = 0;
} else if( ~psu && ~psv ) {
re = pre[psu][psv];
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
oc[col[i]] = 1;
if( que( rt[ps[psu]] , col[i] ) + que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) * 2 == 0 )
++ re;
}
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
} else {
if( ~psu ) swap( u , v ) , swap( psu , psv );
pr = l; t = cn;
while( !im[pr] ) {
pr = par[pr];
if( !oc[c[pr]] ) {
oc[c[pr]] = 1;
if (que(rt[ps[psv]], c[pr]) - que(rt[par[l]], c[pr]) == 0)
--re;
col[++ cn] = c[pr];
}
}
re += pre[im[pr]][psv];
for( int i = t + 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
cn = t;
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
oc[col[i]] = 1;
if( que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) == 0 )
++ re;
}
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
}
printf("%d
",last = re);
}
}