题目链接:http://poj.org/problem?id=1988
有n个元素,开始每个元素自己 一栈,有两种操作,将含有元素x的栈放在含有y的栈的顶端,合并为一个栈。第二种操作是询问含有x元素下面有多少个元素。
经典的带权并查集,cnt表示包含这个元素的集合中所有元素个数,dis表示这个元素离最上面元素的个数(距离)。
看代码领会一下吧。
1 #include <iostream> 2 #include <cstring> 3 #include <cstdio> 4 using namespace std; 5 int par[int(3e4 + 5)] , cnt[int(3e4 + 5)] , dis[int(3e4 + 5)]; 6 //cnt[i]表示i所在点的大小,dis[i]表示i离最上面节点的距离(个数) 7 int Find(int n) { 8 if(par[n] == n) 9 return n; 10 int temp = Find(par[n]); //先算出n的父节点的dis 11 dis[n] += dis[par[n]]; 12 par[n] = temp; //路径压缩 13 return temp; 14 } 15 16 void Union(int u , int v) { 17 int fu = Find(u) , fv = Find(v); 18 if(fu == fv) 19 return ; 20 par[fv] = fu; 21 dis[fv] = cnt[fu]; //Find函数中dis[fv]并没有回溯增加过 22 cnt[fu] += cnt[fv]; //总个数相加 23 } 24 25 int main() 26 { 27 int n , u , v; 28 char q[3]; 29 while(~scanf("%d" , &n)) { 30 memset(dis , 0 , sizeof(dis)); 31 for(int i = 1 ; i <= 3e4 ; ++i) { 32 cnt[i] = 1; 33 par[i] = i; 34 } 35 for(int i = 0 ; i < n ; ++i) { 36 scanf("%s" , q); 37 if(q[0] == 'M') { 38 scanf("%d %d" , &u , &v); 39 Union(u , v); 40 } 41 else { 42 scanf("%d" , &u); 43 int x = Find(u); //回溯累加一次 44 printf("%d " , cnt[x] - dis[u] - 1); 45 } 46 } 47 } 48 return 0; 49 }