设 T 为一棵有根树,我们做如下的定义:
• 设 a 和 b 为 T 中的两个不同节点。如果 a 是 b 的祖先,那么称“a 比 b 不知道高明到哪里去了”。
• 设 a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定常数 x,那么称“a 与 b 谈笑风生”。
给定一棵 n 个节点的有根树 T,节点的编号为 1 ∼ n,根节点为 1 号节点。你需要回答 q 个询问,询问给定两个整数 p 和 k,问有多少个有序三元组 (a; b; c) 满足:
-
a、 b 和 c 为 T 中三个不同的点,且 a 为 p 号节点;
-
a 和 b 都比 c 不知道高明到哪里去了;
-
a 和 b 谈笑风生。这里谈笑风生中的常数为给定的 k
Solution
其实看标题就知道这题的解法了。
发现我们求的是min(deep[p]-1,k)*(size[p]-1)加上以p为根的子树里所有深度在deep[p]+1~deep[p]+k的点的size和。
这个东西怎么求?
按照dfs序建立主席树,下标为节点深度。
Code
#include<iostream> #include<cstdio> #define N 300003 using namespace std; typedef long long ll; int size[N],deep[N],mad,head[N],tot,top,ji,dfn[N],hui[N],L[N*22],R[N*22],re[N],n,q,x,y,T[N]; ll tr[N*22]; struct ds{ int n,to; }e[N<<1]; inline void add(int u,int v){ e[++tot].n=head[u]; e[tot].to=v; head[u]=tot; } void dfs(int u,int fa){ size[u]=1;deep[u]=deep[fa]+1; mad=max(mad,deep[u]); dfn[u]=++top; re[top]=u; for(int i=head[u];i;i=e[i].n){ int v=e[i].to; if(v==fa)continue; dfs(v,u); size[u]+=size[v]; } hui[u]=top; } int build(int l,int r){ int p=++ji; if(l==r)return p; int mid=(l+r)>>1; L[p]=build(l,mid);R[p]=build(mid+1,r); return p; } int update(int pre,int l,int r,int x,ll y){ int p=++ji; L[p]=L[pre];R[p]=R[pre];tr[p]=tr[pre]+y; if(l==r)return p; int mid=(l+r)>>1; if(mid>=x)L[p]=update(L[pre],l,mid,x,y); else R[p]=update(R[pre],mid+1,r,x,y); return p; } int rd(){ int x=0;char c=getchar(); while(!isdigit(c))c=getchar(); while(isdigit(c)){ x=(x<<1)+(x<<3)+(c^48); c=getchar(); } return x; } ll query(int pre,int now,int l,int r,int LL,int RR){ if(l>=LL&&r<=RR)return tr[now]-tr[pre]; int mid=(l+r)>>1; ll ans=0; if(mid>=LL)ans+=query(L[pre],L[now],l,mid,LL,RR); if(mid<RR)ans+=query(R[pre],R[now],mid+1,r,LL,RR); return ans; } int main(){ n=rd();q=rd(); int pu,k; for(int i=1;i<n;++i)x=rd(),y=rd(),add(x,y),add(y,x); dfs(1,0); T[0]=build(1,mad); for(int i=1;i<=n;++i)T[i]=update(T[i-1],1,mad,deep[re[i]],size[re[i]]-1); while(q--){ pu=rd();k=rd();ll num=0; if(deep[pu]!=mad)num=query(T[dfn[pu]],T[hui[pu]],1,mad,deep[pu]+1,min(deep[pu]+k,mad)); printf("%lld ",((ll)min((ll)deep[pu]-1,(ll)k))*((ll)size[pu]-1)+num); } return 0; }