$dfs$序,线段树。
可以统计每一个节点作为$root$的子树上对答案的贡献,可以将树转换成序列。问题就变成了一段区间上求小于等于某个值的数有几个。用线段树记录排好序之后的区间序列,询问的时候,属于询问区间的每个节点二分一下统计答案即可。
#pragma comment(linker, "/STACK:1024000000,1024000000") #include<cstdio> #include<cstring> #include<cmath> #include<algorithm> #include<vector> #include<map> #include<set> #include<queue> #include<stack> #include<iostream> using namespace std; typedef long long LL; const double pi=acos(-1.0),eps=1e-6; void File() { freopen("D:\in.txt","r",stdin); freopen("D:\out.txt","w",stdout); } template <class T> inline void read(T &x) { char c=getchar(); x=0; while(!isdigit(c)) c=getchar(); while(isdigit(c)) {x=x*10+c-'0'; c=getchar();} } const int maxn=100010; int T,n,h[maxn],sz,r[maxn],root; LL k,v[maxn]; struct Edge { int u,v,nx; }e[maxn]; LL a[2*maxn],L[2*maxn],R[2*maxn]; vector<int>s[8*maxn]; void add(int u,int v) { e[sz].u=u; e[sz].v=v; e[sz].nx=h[u]; h[u]=sz++; } void dfs(int x) { sz++; a[sz]=v[x]; L[x]=sz; for(int i=h[x];i!=-1;i=e[i].nx) dfs(e[i].v); sz++; a[sz]=v[x]; R[x]=sz; } void build(int l,int r,int rt) { if(l==r) { s[rt].push_back(a[l]); return; } int m=(l+r)/2; build(l,m,2*rt); build(m+1,r,2*rt+1); int sum=0,p1=0,p2=0; while(sum<r-l+1) { if(p1<s[2*rt].size()&&p2<s[2*rt+1].size()) { if(s[2*rt][p1]<s[2*rt+1][p2]) s[rt].push_back(s[2*rt][p1]), p1++; else s[rt].push_back(s[2*rt+1][p2]), p2++; } else if(p1<s[2*rt].size()) s[rt].push_back(s[2*rt][p1]), p1++; else s[rt].push_back(s[2*rt+1][p2]), p2++; sum++; } } int get(int L,int R,LL num,int l,int r,int rt) { if(L<=l&&r<=R) { int left=0,right=r-l,pos=-1; while(left<=right) { int mid=(left+right)/2; if((LL)s[rt][mid]>num) right=mid-1; else left=mid+1,pos=mid; } return pos+1; } int m=(l+r)/2,x1=0,x2=0; if(L<=m) x1=get(L,R,num,l,m,2*rt); if(R>m) x2=get(L,R,num,m+1,r,2*rt+1); return x1+x2; } int main() { scanf("%d",&T); while(T--) { scanf("%d%lld",&n,&k); for(int i=1;i<=n;i++) scanf("%lld",&v[i]); memset(h,-1,sizeof h); memset(r,sz=0,sizeof r); for(int i=0;i<8*maxn;i++) s[i].clear(); for(int i=1;i<=n-1;i++) { int u,v; scanf("%d%d",&u,&v); add(u,v); r[v]++; } for(int i=1;i<=n;i++) if(r[i]==0) root=i; sz=0; dfs(root); build(1,2*n,1); LL Ans=0; for(int i=1;i<=n;i++) { if(L[i]+1==R[i]) continue; if(v[i]==0) { Ans=Ans+(R[i]-L[i]-1); continue; } Ans=Ans+get(L[i]+1,R[i]-1,k/v[i],1,2*n,1); } printf("%lld ",Ans/2); } return 0; }