给一颗n个点的有点权的树,有m个询问,对于每个询问u,v,k,首先将点u到点v的最短路径上的所有点按顺序编号,u的编号为1,求树链上所有点的新编号cnt满足cnt%k==0的点的权值的最大值。n,m,k<=10^5
根据k的大小分成两部分处理。原问题可转化为 deep[i] % k = a / b 。
对于k较大的,直接暴力,按照dfs序用一个栈记录下所经过的点,对于每个询问的点不停往上爬。
对于k较小的,将询问按照k分类。对于每一种k,将所有点按照dep[i] % k分类,将每个点树链剖分后hash下来的坐标再按照dep[i] % k映射到一起,用线段树进行维护。
每次查询deep[i] % k = a 时,相当于在某个区间查询最大值。
#include <map> #include <set> #include <stack> #include <queue> #include <cmath> #include <ctime> #include <string> #include <vector> #include <cstdio> #include <cstdlib> #include <cstring> #include <cassert> #include <iostream> #include <algorithm> #pragma comment(linker,"/STACK:102400000,102400000") using namespace std; #define N 100008 #define LL long long #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 #define clr(x,v) memset(x,v,sizeof(x)); #define bitcnt(x) __builtin_popcount(x) #define rep(x,y,z) for (int x=y;x<=z;x++) #define repd(x,y,z) for (int x=y;x>=z;x--) const int mo = 1000000007; const int inf = 0x3f3f3f3f; const int INF = 2000000000; /**************************************************************************/ int T,n,m,k,sum,block,stop,cq2,tmp,label; int lt[N],a[N],dep[N],st[N],ans[N],size[N],son[N],rk[N],w[N],top[N],mx[N<<2],cl[N],cr[N],fuck[N];; int f[N][20]; struct line{ int u,v,nt; }eg[N*2]; struct que{ int u,v,lca,k,id,flag; }; vector <que> q1[N],q2[N]; void add(int u,int v){ eg[++sum]=(line){u,v,lt[u]}; lt[u]=sum; } void init(){ clr(lt,0); clr(ans,0); sum=1; tmp=0; } void dfs(int u,int fa){ f[u][0]=fa; dep[u]=dep[fa]+1; size[u]=1; son[u]=0; for (int i=lt[u];i;i=eg[i].nt){ int v=eg[i].v; if (v==fa) continue; dfs(v,u); size[u]+=size[v]; if (size[v]>size[son[u]]) son[u]=v; } } void dfs_2(int u,int tp){ top[u]=tp; w[u]=++tmp; rk[tmp]=u; if (son[u]) dfs_2(son[u],tp); for (int i=lt[u];i;i=eg[i].nt){ int v=eg[i].v; if (v==f[u][0] || v==son[u]) continue; dfs_2(v,v); } } int lca(int x,int y){ if (dep[x]<dep[y]) swap(x,y); int d=dep[x]-dep[y]; repd(i,18,0) if (d & (1<<i)) x=f[x][i]; if (x==y) return x; repd(i,18,0) if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i]; return f[x][0]; } void work_1(int u){ st[++stop]=u; for (int i=0;i<q1[u].size();i++){ int v=q1[u][i].v,x=q1[u][i].lca,k=q1[u][i].k,id=q1[u][i].id,flag=q1[u][i].flag; if (flag) { int tmp=stop-k+1; while (tmp>=0 && dep[st[tmp]]>=dep[x]){ ans[id]=max(ans[id],a[st[tmp]]); tmp-=k; } } else { int tmp=dep[x]+ k - (dep[v]-dep[x]+1) % k ; while (tmp<=stop && dep[st[tmp]]<=dep[u]){ ans[id]=max(ans[id],a[st[tmp]]); tmp+=k; } } } for (int i=lt[u];i;i=eg[i].nt){ int v=eg[i].v; if (v==f[u][0]) continue; work_1(v); } st[stop--]=0; } int query(int L,int R,int l,int r,int rt){ if (L<=l && r<=R){ return mx[rt]; } int m=(l+r)>>1; int res=0; if (L <= m) res=max(res,query(L,R,lson)); if (m < R) res=max(res,query(L,R,rson)); return res; } void build(int l,int r,int rt){ mx[rt]=0; if (l==r){ mx[rt]=a[rk[fuck[l]]]; return; } int m=(l+r)>>1; build(lson); build(rson); mx[rt]=max(mx[rt<<1],mx[rt<<1|1]); } int calc(int L, int R, int k) { int l = lower_bound(fuck + cl[k], fuck + cr[k] + 1, L) - fuck; int r = upper_bound(fuck + cl[k], fuck + cr[k] + 1, R) - fuck - 1; return l <= r ? query(l,r,1,n,1) : 0; } void find(int tp,int x,int cl,int id){ while (top[x]!=top[tp]){ ans[id]=max(ans[id],calc(w[top[x]],w[x],cl)); x=f[top[x]][0]; } ans[id]=max(ans[id],calc(w[tp],w[x],cl)); } vector <int> E[N]; void work_2(int k){ rep(i,0,k) E[i].clear(); rep(i,1,n) E[(dep[rk[i]]) % k].push_back(i); label=0; rep(i,0,k-1){ cl[i]=label+1; for (int j=0;j<E[i].size();j++) fuck[++label]=E[i][j]; cr[i]=label; } build(1,n,1); for (int i=0;i<q2[k].size();i++){ int u=q2[k][i].u,v=q2[k][i].v,x=q2[k][i].lca,id=q2[k][i].id; int tmp1=(dep[x]+(dep[u]-dep[x]+1) % k) % k; int tmp2=(dep[x]+k-(dep[u]-dep[x]+1) % k) % k; find(x,u,tmp1,id); find(x,v,tmp2,id); } } int main(){ int cas=0; scanf("%d",&T); while (T--){ printf("Case #%d: ",++cas ); init(); scanf("%d%d",&n,&m); rep(i,1,n) scanf("%d",&a[i]); rep(i,1,n-1){ int u,v; scanf("%d%d",&u,&v); add(u,v); add(v,u); } dfs(1,0); dfs_2(1,1); rep(j,1,18) rep(i,1,n) f[i][j]=f[f[i][j-1]][j-1]; block=20; cq2=0; rep(i,1,n) q1[i].clear(); rep(i,1,n) q2[i].clear(); rep(i,1,m){ int u,v,k; scanf("%d%d%d",&u,&v,&k); int x=lca(u,v); if (k>block) { q1[u].push_back((que){u,v,x,k,i,1}); q1[v].push_back((que){v,u,x,k,i,0}); } else { q2[k].push_back((que){u,v,x,k,i,1}); } } work_1(1); rep(i,1,block) if (q2[i].size()) work_2(i); rep(i,1,m) printf("%d ",ans[i]); } }