学了差不多一星期的主席树+树链剖分,再来看这题发现其实是个板子题
一开始想复杂了,以为要用类似求树上第k大的树上差分思想来解决这道题,但其实树链上<=k的元素个数其实直接可以用树链剖分来求
具体是把每条树链放到主席树上询问一下求和就好了
#include<bits/stdc++.h> using namespace std; #define maxn 100006 struct Edge{int to,nxt,w;}edge[maxn<<1]; int b[maxn],n,m,a[maxn],head[maxn],tot; void init(){memset(head,-1,sizeof head);tot=0;} void addedge(int u,int v,int w){ edge[tot].to=v;edge[tot].nxt=head[u];edge[tot].w=w;head[u]=tot++; } struct Node{int lc,rc,sum;}T[maxn*25]; int siz,rt[maxn]; int build(int l,int r){ int now=++siz; T[now].lc=T[now].rc=T[now].sum=0; if(l==r)return now; int mid=l+r>>1; T[now].lc=build(l,mid); T[now].rc=build(mid+1,r); return now; } int update(int last,int pos,int l,int r){//更新到pos点 int now=++siz; T[now]=T[last];T[now].sum++; if(l==r)return now; int mid=l+r>>1; if(pos<=mid)T[now].lc=update(T[last].lc,pos,l,mid); else T[now].rc=update(T[last].rc,pos,mid+1,r); return now; } int query(int st,int ed,int L,int R,int l,int r){ if(L<=l && R>=r)return T[ed].sum-T[st].sum; int mid=l+r>>1,res=0; if(L<=mid)res+=query(T[st].lc,T[ed].lc,L,R,l,mid); if(R>mid)res+=query(T[st].rc,T[ed].rc,L,R,mid+1,r); return res; } int f[maxn],son[maxn],d[maxn],size[maxn]; void dfs1(int x,int pre,int deep){ f[x]=pre;size[x]=1;d[x]=deep; for(int i=head[x];i!=-1;i=edge[i].nxt){ int y=edge[i].to; if(y==pre)continue; a[y]=edge[i].w; dfs1(y,x,deep+1); size[x]+=size[y]; if(size[y]>size[son[x]])son[x]=y; } } int id[maxn],rk[maxn],idx,top[maxn]; void dfs2(int x,int tp){ top[x]=tp;id[x]=++idx;rk[idx]=x; if(son[x])dfs2(son[x],tp); for(int i=head[x];i!=-1;i=edge[i].nxt){ int y=edge[i].to; if(y!=son[x] && y!=f[x])dfs2(y,y); } } int Query(int x,int y,int pos){ int res=0; while(top[x]!=top[y]){ if(d[top[x]]<d[top[y]])swap(x,y); res+=query(rt[id[top[x]]-1],rt[id[x]],1,pos,1,m); x=f[top[x]]; } if(id[x]>id[y])swap(x,y); res+=query(rt[id[x]],rt[id[y]],1,pos,1,m); return res; } int main(){int q;init(); cin>>n>>q;int u,v,w,k; for(int i=1;i<n;i++){ scanf("%d%d%d",&u,&v,&w); addedge(u,v,w);addedge(v,u,w); } a[1]=0x3f3f3f3f; siz=0;dfs1(1,1,1);dfs2(1,1);//树剖 for(int i=1;i<=n;i++)b[++m]=a[i]; sort(b+1,b+1+m); m=unique(b+1,b+1+m)-(b+1); rt[0]=build(1,m); for(int i=1;i<=idx;i++){ int pos=lower_bound(b+1,b+1+m,a[rk[i]])-b; rt[i]=update(rt[i-1],pos,1,m); } while(q--){ scanf("%d%d%d",&u,&v,&k); int pos=upper_bound(b+1,b+1+m,k)-(b+1); if(pos==0){puts("0");continue;} else cout<<Query(u,v,pos)<<' '; } }