xor
There is a tree with $n$ nodes. Each node $i$ holds an integer value $a_i$ ($1 \le a_i \le 1{,}000{,}000{,}000$ for $1 \le i \le n$). There are $q$ queries, described as follows: let the values on the path from node $a$ to node $b$ be $t_0, t_1, \cdots, t_m$. You are supposed to calculate $t_0$ xor $t_k$ xor $t_{2k}$ xor ... xor $t_{pk}$ ($pk \le m$).
Input Format
There are multiple datasets ($\sum n \le 50{,}000$, $\sum q \le 500{,}000$).
For each dataset: in the first $n-1$ lines, there are two integers $u, v$, indicating that there is an edge connecting node $u$ and node $v$.
In the next $n$ lines, there is an integer $a_i$ ($1 \le a_i \le 1{,}000{,}000{,}000$).
In the next $q$ lines, there are three integers $a$, $b$ and $k$ ($1 \le a, b, k \le n$).
Output Format
For each query, output an integer in one line, without any additional space.
样例输入
5 6
1 5
4 1
2 1
3 2
19
26
0
8
17
5 5 1
1 3 2
3 2 1
5 4 2
3 4 4
1 4 5
样例输出
17
19
26
25
0
19
题目来源
【题意】给您一棵树,每个节点有一个权值,q次询问,每次给出u,v,k,求从u到v的路径中,从u开始,每隔k个节点异或一下的结果。
【分析】根号分治,大于根号n的暴力跳(倍增跳),小于根号n的用数组存起来,复杂度最高为O(N)*sqrt(N)*log(N).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// Path "every k-th node" XOR queries on a tree, split by step size:
//   * k >  sz : walk the path in jumps of k using binary lifting (solve);
//   * k <= sz : use the precomputed chain XOR table up[][] (solve_small).
// sz is chosen as round(sqrt(n)) for each dataset.
//
// Node ids are 1-based; id 0 is the "no such node" sentinel. fa[0][*],
// up[0][*] and dep[0] are never written and therefore stay 0.
// Input is trusted to respect the stated limits (n <= 50,000, valid node ids,
// 1 <= k <= n); no range checks are performed on it.

const int N = 5e4 + 50;  // capacity for node-indexed arrays
const int M = 255;       // capacity bound for the small-step table
const int LOG = 20;      // binary-lifting levels actually used (2^20 > 2 * N)

int n, q, sz;
int a[N];           // a[u]: value stored at node u
int fa[N][25];      // fa[u][i]: 2^i-th ancestor of u, 0 if it does not exist
int up[N][M + 10];  // up[u][k]: a[u] ^ a[k-th ancestor] ^ a[2k-th ancestor] ^ ...
int dep[N];         // dep[u]: depth of u, root has depth 1
std::vector<int> edg[N];  // adjacency lists

// Returns the k-th ancestor of u, or 0 when the walk leaves the tree.
int find(int u, int k) {
    for (int i = LOG - 1; i >= 0 && u != 0; --i) {
        if ((k >> i) & 1) {
            u = fa[u][i];
        }
    }
    return u;
}

// Fills fa[][], up[][1..sz] and dep[] for the subtree of u, whose parent is f.
// dep[u] must be set by the caller.
// Uses an explicit stack instead of recursion so that a path-shaped tree does
// not recurse n levels deep. A node is only pushed after its parent has been
// fully processed, so every ancestor entry read here is already final.
void dfs(int u, int f) {
    fa[u][0] = f;
    std::vector<int> pending(1, u);
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();

        for (int i = 1; i < LOG; ++i) {
            fa[cur][i] = fa[fa[cur][i - 1]][i - 1];
        }
        // find() returns 0 when the chain ends, and up[0][i] is 0.
        for (int i = 1; i <= sz; ++i) {
            up[cur][i] = a[cur] ^ up[find(cur, i)][i];
        }

        const int parent = fa[cur][0];
        for (std::size_t j = 0; j < edg[cur].size(); ++j) {
            const int child = edg[cur][j];
            if (child == parent) continue;
            fa[child][0] = cur;
            dep[child] = dep[cur] + 1;
            pending.push_back(child);
        }
    }
}

// Lowest common ancestor of u and v via binary lifting.
int LCA(int u, int v) {
    if (dep[u] < dep[v]) std::swap(u, v);
    // Lift the deeper node to the depth of the other one. dep[0] == 0 is
    // below every real depth, so the sentinel is never selected.
    for (int i = LOG - 1; i >= 0; --i) {
        if (dep[fa[u][i]] >= dep[v]) {
            u = fa[u][i];
        }
    }
    if (u == v) return u;
    for (int i = LOG - 1; i >= 0; --i) {
        if (fa[u][i] != fa[v][i]) {
            u = fa[u][i];
            v = fa[v][i];
        }
    }
    return fa[u][0];
}

// Answers one query by jumping along the path in steps of k.
// u, v: path endpoints; k: step; lca: LCA(u, v).
// Returns a[t_0] ^ a[t_k] ^ a[t_2k] ^ ... over the path t_0 = u, ..., t_m = v.
int solve(int u, int v, int k, int lca) {
    const int start = u;
    const int target = v;
    // Distance from v back to the last selected path node.
    const int res = (dep[u] + dep[v] - 2 * dep[lca]) % k;

    int ans = 0;
    // Upward half: u, its k-th ancestor, ... while not above the lca.
    while (dep[u] >= dep[lca]) {
        ans ^= a[u];
        u = find(u, k);
        if (!u) break;
    }

    v = find(v, res);
    // No selected node lies strictly below the lca on v's side.
    if (target == lca || v == 0 || dep[v] <= dep[lca]) return ans;

    const int last = v;
    while (dep[v] >= dep[lca]) {
        ans ^= a[v];
        v = find(v, k);
        if (!v) break;
    }
    // When both walks land on the lca it was XOR-ed twice; add it back once.
    if ((dep[start] - dep[lca]) % k == 0 && (dep[last] - dep[lca]) % k == 0) {
        ans ^= a[lca];
    }
    return ans;
}

// Answers one query with k <= sz from the precomputed up[][k] table.
// Same contract as solve().
static int solve_small(int u, int v, int k, int lca) {
    int dis = dep[u] - dep[lca];
    // x: first node of u's k-chain strictly above the lca (0 if none).
    int x = find(u, (dis / k + 1) * k);
    int ans = up[u][k] ^ up[x][k];

    const int res = (dep[u] + dep[v] - 2 * dep[lca]) % k;
    if (lca != v && dep[v] - dep[lca] > res) {
        v = find(v, res);  // last selected node, strictly below the lca
        dis = dep[v] - dep[lca];
        x = find(v, (dis / k + 1) * k);
        ans ^= up[v][k] ^ up[x][k];
        // The lca sits on both chains: it cancelled out, so add it back once.
        if ((dep[u] - dep[lca]) % k == 0 && (dep[v] - dep[lca]) % k == 0) {
            ans ^= a[lca];
        }
    }
    return ans;
}

// Resets per-dataset state. Must be called after n has been read.
// Only rows 0..n are touched: wiping the whole fa/up arrays for every dataset
// costs tens of megabytes of writes each time, which is prohibitive when the
// input consists of many small datasets.
void init() {
    for (int i = 0; i <= n; ++i) {
        edg[i].clear();
        std::fill(fa[i], fa[i] + 25, 0);
        std::fill(up[i], up[i] + M + 10, 0);
    }
}

int main() {
    // Stop on EOF *and* on a partial/failed match.
    while (std::scanf("%d%d", &n, &q) == 2) {
        init();
        // Threshold between table lookups and jumping; capped to the table width.
        sz = std::min(static_cast<int>(std::round(std::sqrt(static_cast<double>(n)))), M);

        for (int i = 1; i < n; ++i) {
            int u, v;
            if (std::scanf("%d%d", &u, &v) != 2) return 0;
            edg[u].push_back(v);
            edg[v].push_back(u);
        }
        for (int i = 1; i <= n; ++i) {
            if (std::scanf("%d", &a[i]) != 1) return 0;
        }

        dep[1] = 1;
        dfs(1, 0);

        while (q-- > 0) {
            int u, v, k;
            if (std::scanf("%d%d%d", &u, &v, &k) != 3) return 0;
            const int lca = LCA(u, v);
            const int ans = (k > sz) ? solve(u, v, k, lca) : solve_small(u, v, k, lca);
            // One answer per line, no trailing space.
            std::printf("%d\n", ans);
        }
    }
    return 0;
}