枚举最大点 i，统计以 i 为最大点的连通块个数（连通块中只能包含权值在 [a_i - d, a_i] 内的点）。由于一个连通块可能有多个权值最大的点，为避免重复计数，我们强制规定其中编号最小的那个为最大点，即权值等于 a_i 且编号小于 i 的点不允许加入。
// Counts the non-empty connected vertex subsets S of a tree such that
// max(a[v]) - min(a[v]) over v in S is at most d, modulo 1e9+7.
//
// Input : d n, then a_1..a_n, then n-1 undirected edges "x y" (1-indexed).
// Output: the count, followed by a newline.
//
// Method: for every vertex `root`, count the connected subsets that contain
// `root` and in which `root` is the designated maximum. Such subsets may only
// use vertices with weight in [a[root] - d, a[root]]. Among vertices whose
// weight equals a[root], only those with index >= root may join. Each valid
// subset is therefore counted once, under its smallest-index maximum.
// Total work is O(n^2).
#include <cstdio>
#include <vector>

namespace {

constexpr long long kMod = 1000000007LL;

int n = 0;
long long d = 0;
std::vector<long long> a;           // a[v] = weight of vertex v (1-indexed)
std::vector<std::vector<int>> adj;  // adjacency lists of the tree

// True when vertex v may belong to a subset whose designated maximum is root.
bool allowed(int v, int root) {
    if (a[v] > a[root] || a[v] < a[root] - d) return false;
    // Tie-break: an equal-weight vertex with a smaller index would itself be
    // the designated maximum, so exclude it to avoid double counting.
    return !(a[v] == a[root] && v < root);
}

// Number (mod kMod) of connected subsets of allowed vertices that contain x
// and lie within x's subtree, with the tree rooted at `root`.
long long count_rooted(int x, int parent, int root) {
    long long ways = 1;
    for (int y : adj[x]) {
        if (y == parent || !allowed(y, root)) continue;
        // Each child subtree is either left out (1 way) or contributes one of
        // its own connected subsets containing y. Both factors are < kMod,
        // so the product fits in long long.
        ways = ways * (1 + count_rooted(y, x, root)) % kMod;
    }
    return ways;
}

}  // namespace

int main() {
    if (std::scanf("%lld%d", &d, &n) != 2 || n < 0) return 1;

    a.assign(n + 1, 0);
    adj.assign(n + 1, std::vector<int>());

    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%lld", &a[i]) != 1) return 1;
    }
    for (int i = 1; i < n; ++i) {
        int x = 0;
        int y = 0;
        if (std::scanf("%d%d", &x, &y) != 2) return 1;
        if (x < 1 || x > n || y < 1 || y > n) return 1;  // reject bad endpoints
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    long long ans = 0;
    for (int root = 1; root <= n; ++root) {
        ans = (ans + count_rooted(root, 0, root)) % kMod;
    }
    std::printf("%lld\n", ans);
    return 0;
}