题目链接在这里:Problem - G - Codeforces
这道题主要要解决的问题就是避免同样的k条链被选择了多次。所以如何通过某一个特征保证同样的k条链只被选择一次是非常关键的,我们发现,在从下向上回溯的时候,如果这k条链有一条已经到了lca,那么再往上遍历的话这条链就一定不会再出现了,所以我们就计算存在至少一条边已经到了lca的k条边就行了。具体实现的话就是枚举每个lca再统计。
注意阶乘和逆元可以预处理,二者预处理的时候注意要处理0
取模的时候如果有减号要先加上MOD再取模
1 #include "bits/stdc++.h" 2 using namespace std; 3 typedef long long LL; 4 const int MAX=3e5+5; 5 const int MOD=1e9+7; 6 LL n,m,k,t,fa[MAX][35],deep[MAX],jie[MAX],nij[MAX],an1[MAX],an2[MAX],ans; 7 LL tot,head[MAX],adj[MAX<<1],nxt[MAX<<1]; 8 inline LL ksm(LL x,LL y){ 9 LL an=1; 10 while (y){ 11 if (y&1) an=an*x%MOD; 12 x=x*x%MOD; 13 y>>=1; 14 } 15 return an; 16 } 17 void addedge(LL u,LL v){ 18 tot++; 19 adj[tot]=v; 20 nxt[tot]=head[u]; 21 head[u]=tot; 22 } 23 void dfs(int x,int ff){ 24 LL i,j; 25 for (i=1;i<=30;i++){ 26 if (deep[x]<(1<<i)) break; 27 fa[x][i]=fa[fa[x][i-1]][i-1]; 28 } 29 for (i=head[x];i;i=nxt[i]){ 30 if (adj[i]==ff) continue; 31 deep[adj[i]]=deep[x]+1; 32 fa[adj[i]][0]=x; 33 dfs(adj[i],x); 34 } 35 } 36 LL lca(LL x,LL y){ 37 LL i,j,zt; 38 if (deep[x]<deep[y]) swap(x,y); 39 zt=deep[x]-deep[y]; 40 for (i=30;i>=0;i--) 41 if (zt&(1<<i)) 42 x=fa[x][i]; 43 for (i=30;i>=0;i--) 44 if (fa[x][i]!=fa[y][i]) 45 x=fa[x][i],y=fa[y][i]; 46 return x==y?x:fa[x][0]; 47 } 48 LL C(LL x,LL y){ 49 if (x<y) return 0; 50 return jie[x]*nij[y]%MOD*nij[x-y]%MOD; 51 } 52 void dfs2(LL x,LL ff){ 53 LL i,j; 54 for (i=head[x];i;i=nxt[i]){ 55 if (adj[i]==ff) continue; 56 dfs2(adj[i],x); 57 an2[x]+=an2[adj[i]]; 58 } 59 } 60 int main(){ 61 freopen ("g.in","r",stdin); 62 freopen ("g.out","w",stdout); 63 LL i,j,zt,u,v,ll; 64 scanf("%lld",&t); 65 jie[0]=nij[0]=jie[1]=1; 66 for (i=2;i<MAX;i++) jie[i]=jie[i-1]*i%MOD; 67 nij[MAX-1]=ksm(jie[MAX-1],MOD-2); 68 for (i=MAX-2;i>=1;i--) nij[i]=nij[i+1]*(i+1)%MOD; 69 while (t--){ 70 scanf("%lld%lld%lld",&n,&m,&k); 71 memset(head,0,sizeof(head)); 72 memset(fa,0,sizeof(fa)); 73 tot=0; 74 for (i=1;i<n;i++){ 75 scanf("%lld%lld",&u,&v); 76 addedge(u,v); 77 addedge(v,u); 78 } 79 memset(an1,0,sizeof(an1)); 80 memset(an2,0,sizeof(an2)); 81 memset(deep,0,sizeof(deep)); 82 dfs(1,0); 83 for (i=1;i<=m;i++){ 84 scanf("%lld%lld",&u,&v); 85 ll=lca(u,v); 86 an1[ll]++; 87 an2[u]++;an2[v]++; 88 an2[ll]--;an2[fa[ll][0]]--; 89 } 90 dfs2(1,0); 91 ans=0; 92 for (i=1;i<=n;i++) 93 ans=(ans+(C(an2[i],k)-C(an2[i]-an1[i],k)+MOD)%MOD)%MOD; 94 printf("%lld ",ans); 95 } 96 return 0; 97 }