对于询问dist,树链剖分搞之,把边权转化到点上,然后注意细节就好(我在代码里标出来了,为了这个细节,wa了一屏)
对于询问kth,可以先求出两点(x和y)的lca,然后判断第k个数字是在x到lca的路径上还是y到lca的路径上,确定之后,倍增的寻找就好了~
View Code
1 #include <iostream> 2 #include <cstring> 3 #include <cstdlib> 4 #include <algorithm> 5 #include <cstdio> 6 7 #define N 50000 8 #define M 100000 9 10 using namespace std; 11 12 int head[N],next[M],to[M],len[M]; 13 int n,tot,cnt; 14 int fa[N],son[M],top[N],dat[N],sum[N<<2],dep[N],sz[N],pre[N],bh[N]; 15 int f[N][22],bit[22]; 16 int q[M]; 17 18 inline void init() 19 { 20 memset(head,-1,sizeof head); cnt=2; tot=0; 21 memset(son,0,sizeof son); 22 memset(fa,0,sizeof fa); 23 memset(f,0,sizeof f); 24 memset(sum,0,sizeof sum); 25 bit[0]=1; 26 for(int i=1;i<=20;i++) bit[i]=bit[i-1]<<1; 27 } 28 29 inline void prep() 30 { 31 int h=1,t=2,sta; 32 q[1]=1; dep[1]=1; 33 while(h<t) 34 { 35 sta=q[h++]; sz[sta]=1; 36 for(int i=head[sta];~i;i=next[i]) 37 if(fa[sta]!=to[i]) 38 { 39 fa[to[i]]=sta; 40 f[to[i]][0]=sta; 41 pre[to[i]]=i^1; 42 dep[to[i]]=dep[sta]+1; 43 q[t++]=to[i]; 44 } 45 } 46 for(int j=t-1;j>=1;j--) 47 { 48 sta=q[j]; 49 for(int i=head[sta];~i;i=next[i]) 50 if(fa[sta]!=to[i]) 51 { 52 sz[sta]+=sz[to[i]]; 53 if(sz[to[i]]>sz[son[sta]]) son[sta]=to[i]; 54 } 55 } 56 for(int i=1;i<t;i++) 57 { 58 sta=q[i]; 59 if(son[fa[sta]]==sta) top[sta]=top[fa[sta]]; 60 else top[sta]=sta; 61 } 62 } 63 64 inline void rewrite() 65 { 66 for(int i=1;i<=n;i++) 67 if(top[i]==i) 68 for(int j=i;j;j=son[j]) 69 { 70 bh[j]=++tot; 71 dat[tot]=len[pre[j]]; 72 } 73 } 74 75 inline void lcainit() 76 { 77 for(int j=1;j<=20;j++) 78 for(int i=1;i<=n;i++) 79 f[i][j]=f[f[i][j-1]][j-1]; 80 } 81 82 inline void pushup(int x) 83 { 84 sum[x]=sum[x<<1]+sum[x<<1|1]; 85 } 86 87 inline void build(int u,int L,int R) 88 { 89 if(L==R) {sum[u]=dat[L];return;} 90 int MID=(L+R)>>1; 91 build(u<<1,L,MID); build(u<<1|1,MID+1,R); 92 pushup(u); 93 } 94 95 inline void add(int u,int v,int w) 96 { 97 to[cnt]=v; len[cnt]=w; next[cnt]=head[u]; head[u]=cnt++; 98 } 99 100 inline void read() 101 { 102 init(); 103 scanf("%d",&n); 104 for(int i=1,a,b,c;i<n;i++) 105 { 106 scanf("%d%d%d",&a,&b,&c); 107 add(a,b,c); add(b,a,c); 108 } 109 prep(); 110 rewrite(); 111 build(1,1,tot); 112 lcainit(); 113 } 114 115 inline int querysum(int u,int L,int R,int l,int r) 116 { 117 if(l<=L&&R<=r) return sum[u]; 118 int MID=(L+R)>>1,res=0; 119 if(l<=MID) res+=querysum(u<<1,L,MID,l,r); 120 if(MID<r) res+=querysum(u<<1|1,MID+1,R,l,r); 121 return res; 122 } 123 124 inline int getsum(int x,int y) 125 { 126 int res=0; 127 while(top[x]!=top[y]) 128 { 129 if(dep[top[x]]<dep[top[y]]) swap(x,y); 130 res+=querysum(1,1,tot,bh[top[x]],bh[x]); 131 x=fa[top[x]]; 132 } 133 if(x==y) return res;//这句话好坑啊!把边权转移到点权上时会出现这个问题! 134 if(bh[x]>bh[y]) swap(x,y); 135 res+=querysum(1,1,tot,bh[son[x]],bh[y]);//细节 136 return res; 137 } 138 139 inline int getlca(int x,int y) 140 { 141 if(dep[x]<dep[y]) swap(x,y); 142 for(int i=20;i>=0;i--) 143 if(dep[f[x][i]]>=dep[y]) x=f[x][i]; 144 if(x==y) return x; 145 for(int i=20;i>=0;i--) 146 if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i]; 147 return f[x][0]; 148 } 149 150 inline int getlen(int x,int lca) 151 { 152 int res=0; 153 for(int i=20;i>=0;i--) 154 if(dep[f[x][i]]>=dep[lca]) x=f[x][i],res+=bit[i]; 155 return res; 156 } 157 158 inline int getnum(int x,int p) 159 { 160 int res=0; 161 for(int i=20;i>=0;i--) 162 if(res+bit[i]<=p) x=f[x][i],res+=bit[i]; 163 return x; 164 } 165 166 inline int getkth(int x,int y,int p) 167 { 168 int lca=getlca(x,y); 169 int lx=getlen(x,lca)+1; 170 int ly=getlen(y,lca)+1; 171 if(lx>=p) return getnum(x,p-1); 172 return getnum(y,lx+ly-p-1); 173 } 174 175 inline void go() 176 { 177 char str[10];int a,b,c; 178 while(scanf("%s",str)) 179 { 180 if(str[1]=='O') break; 181 if(str[0]=='K') 182 { 183 scanf("%d%d%d",&a,&b,&c); 184 printf("%d\n",getkth(a,b,c)); 185 } 186 else 187 { 188 scanf("%d%d",&a,&b); 189 printf("%d\n",getsum(a,b)); 190 } 191 } 192 puts(""); 193 } 194 195 int main() 196 { 197 int cas;scanf("%d",&cas); 198 while(cas--) read(),go(); 199 return 0; 200 }