原文链接https://www.cnblogs.com/zhouzhendong/p/51Nod1868.html
题目传送门 - 51Nod1868
题意
给定一颗 $n$个点的树,每个点一个 $[1,n]$ 的颜色。设 $g(x,y)$ 表示 $x$ 到 $y$ 的树上路径上有几种颜色。
对于一个长度为 $n$ 的排列 $P[1cdots n]$ ,定义 $f(P)=sum_{i=1}^{n-1}g(P_i,P_{i+1})$ 。
现在求对于 $n!$ 个排列,他们的 $f(P)$ 之和 对 $10^9+7$ 取模后的值。
题解
首先我们考虑每一个 $g(x,y)$ 对于答案的贡献次数。
考虑捆绑法,把 $x$ 和 $y$ 看作一个整体,显然,它对答案的贡献次数为 $(n-1)!$ 。
于是答案就是
$$2 imes (n-1)!sum_{x=1}^{n}sum_{y=x+1}^{n} g(x,y)$$
前面的 $2 imes (n-1)!$ 很好办,现在主要要求后面的那个。
我们考虑对于每一个颜色分别处理。我们需要求出每一个颜色对答案的贡献。
记 $f(c,x,y)$ 表示路径 $x$~$y$ 上,如果有颜色 $c$ ,那么值为 $1$ ,否则为 $0$ 。则后面一半变成了:
$$sum_{c=1}^{n}sum_{x=1}^{n}sum_{y=x+1}^{n} f(c,x,y)$$
确定一种颜色之后,后面的显然非常好求,直接一个树形dp 就差不多了。但是这样的时间复杂度是炸掉的。于是我们需要一个数据结构来优化——虚树。
建出虚树,然后我们注意一下细节,统计一下就可以了。
这里推荐一个写的比较详细的虚树学习笔记:https://www.k-xzy.xyz/archives/4476
代码
#include <bits/stdc++.h> using namespace std; typedef long long LL; const int N=200005,mod=1e9+7; int read(){ int x=0; char ch=getchar(); while (!isdigit(ch)) ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+ch-48,ch=getchar(); return x; } struct Gragh{ static const int M=N*2; int cnt,y[M],nxt[M],fst[N]; void clear(){ cnt=0; memset(fst,0,sizeof fst); } void add(int a,int b){ y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt; } }g,t; int n,c[N],Fac[N],Time=0,now_color,ans=0; int dfn[N],depth[N],size[N],fa[N][18],sqrsum[N]; int dirson[N],tot[N],st[N],top; vector <int> id[N]; LL calc(int x){ return 1LL*x*(x-1)/2; } void dfs(int x,int pre,int d){ dfn[x]=++Time,depth[x]=d,size[x]=1,fa[x][0]=pre,sqrsum[x]=0; for (int i=1;i<18;i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for (int i=g.fst[x];i;i=g.nxt[i]) if (g.y[i]!=pre){ int y=g.y[i]; dfs(y,x,d+1); size[x]+=size[y]; sqrsum[x]=(calc(size[y])+sqrsum[x])%mod; } } int LCA(int x,int y){ if (depth[x]<depth[y]) swap(x,y); for (int i=17;i>=0;i--) if (depth[x]-(1<<i)>=depth[y]) x=fa[x][i]; if (x==y) return x; for (int i=17;i>=0;i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i]; return fa[x][0]; } bool cmp(int a,int b){ return dfn[a]<dfn[b]; } void solve(int x){ int dx=dirson[x],sonsqr=tot[x]=0; for (int k=t.fst[x];k;k=t.nxt[k]){ int y=t.y[k],&dy=dirson[y]=y; for (int i=17;i>=0;i--) if (depth[dy]-(1<<i)>depth[x]) dy=fa[dy][i]; solve(y); tot[x]+=tot[y]; sonsqr=(calc(tot[y])+sonsqr)%mod; } if (c[x]==now_color){ tot[x]=size[x]; int xsqr=(calc(size[dx]-size[x])+sqrsum[x])%mod; ans=(calc(size[dx])-xsqr+ans)%mod; } else { ans=(calc(tot[x])-sonsqr+ans)%mod; for (int i=t.fst[x];i;i=t.nxt[i]){ int y=t.y[i],v=size[dx]-tot[x]+tot[y]-size[dirson[y]]; ans=(1LL*tot[y]*v+ans)%mod; } } } int main(){ n=read(); for (int i=Fac[0]=1;i<=n;i++) c[i]=read(),Fac[i]=1LL*Fac[i-1]*i%mod; g.clear(); for (int i=1;i<n;i++){ int a=read(),b=read(); g.add(a,b); g.add(b,a); } dfs(1,0,0); for (int i=1;i<=n;i++) id[i].clear(); for (int i=1;i<=n;i++) id[c[i]].push_back(i); t.clear(); for (int k=1;k<=n;k++){ if (id[k].size()<1) continue; sort(id[k].begin(),id[k].end(),cmp); st[top=1]=1,t.fst[1]=0; for (vector <int> :: iterator i=id[k].begin();i!=id[k].end();i++){ int x=*i; if (x==1) continue; int lca=LCA(x,st[top]); if (lca!=st[top]){ while (depth[st[top-1]]>depth[lca]) t.add(st[top-1],st[top]),top--; if (st[top-1]!=lca) t.fst[lca]=0,t.add(lca,st[top]),st[top]=lca; else t.add(lca,st[top--]); } t.fst[x]=0,st[++top]=x; } for (int i=1;i<top;i++) t.add(st[i],st[i+1]); now_color=k,dirson[1]=1; solve(1); } printf("%d ",2LL*(ans+mod)%mod*Fac[n-1]%mod); return 0; }