(mathcal{Description})
给定一棵 (n) 个结点的树,每个结点有一种颜色。记 (g(u,v)) 表示 (u) 到 (v) 简单路径上的颜色种数,求
[sum_{{p_n}}sum_{i=1}^{n-1}g(p_i,p_{i+1})
]
其中 ({p_n}) 表示 (1sim n) 的排列。
(nle10^5),答案对 ((10^9+7)) 取模。
(mathcal{Solution})
常见但不熟悉的 trick 题。(
不难想到分颜色计算贡献。而“路径上至少出现某种颜色”不好计算,考虑反向计算“路径上不包含某种颜色”的路径条数。
对于颜色 (c),删除所有颜色为 (c) 的结点,记得到联通块的大小为 (s_{1..m}),那么上述路径数量为 (sum_{i=1}^mfrac{s_i(s_i-1)}2)。称一个联通块深度最小的结点为其顶点,我们尝试在顶点处对这个联通块计数。可以看出,要不顶点是根,要不顶点父亲的颜色为 (c)。前者单独考虑,后者仅需用顶点子树大小减去顶点子树内以 (c) 颜色的结点作为根的子树大小。所以用前后作差的 trick:DFS 进入子树前,记录目前以 (c) 色点为根的子树大小和 (s),退出子树后,得到当前以 (c) 色点为根的子树大小和 (t)。那么该联通块的大小即为 (t-s)。
所以 DFS 一遍就可以啦。复杂度 (mathcal O(n))。
(mathcal{Code})
/* Clearink */
#include <cstdio>
#include <vector>
#include <algorithm>
#define rep( i, l, r ) for ( int i = l, repEnd##i = r; i <= repEnd##i; ++i )
#define per( i, r, l ) for ( int i = r, repEnd##i = l; i >= repEnd##i; --i )
inline int rint () {
int x = 0, f = 1; char s = getchar ();
for ( ; s < '0' || '9' < s; s = getchar () ) f = s == '-' ? -f : f;
for ( ; '0' <= s && s <= '9'; s = getchar () ) x = x * 10 + ( s ^ '0' );
return x * f;
}
const int MAXN = 1e5, MOD = 1e9 + 7;
int n, ecnt, ans, clr[MAXN + 5], head[MAXN + 5], siz[MAXN + 5], sum[MAXN + 5];
struct Edge { int to, nxt; } graph[MAXN * 2 + 5];
inline int mul ( const long long a, const int b ) { return a * b % MOD; }
inline int sub ( int a, const int b ) { return ( a -= b ) < 0 ? a + MOD : a; }
inline int add ( int a, const int b ) { return ( a += b ) < MOD ? a : a - MOD; }
inline void link ( const int s, const int t ) {
graph[++ecnt] = { t, head[s] };
head[s] = ecnt;
}
inline int count ( const int s ) { return ( s * ( s - 1ll ) >> 1 ) % MOD; }
inline void dfs ( const int u, const int fa ) {
siz[u] = 1, ++sum[clr[u]];
for ( int i = head[u], v; i; i = graph[i].nxt ) {
if ( ( v = graph[i].to ) ^ fa ) {
int s = sum[clr[u]];
dfs ( v, u ), siz[u] += siz[v];
int t = sum[clr[u]], dlt = t - s;
ans = add ( ans, count ( siz[v] - dlt ) );
sum[clr[u]] += siz[v] - dlt;
}
}
}
int main () {
n = rint ();
rep ( i, 1, n ) clr[i] = rint ();
for ( int i = 1, u, v; i < n; ++i ) {
u = rint (), v = rint ();
link ( u, v ), link ( v, u );
}
dfs ( 1, 0 );
rep ( c, 1, n ) {
if ( c ^ clr[1] ) {
ans = add ( ans, count ( n - sum[c] ) );
}
}
ans = sub ( mul ( n, count ( n ) ), ans );
int fct = 1;
rep ( i, 1, n - 2 ) fct = mul ( fct, i );
printf ( "%d", mul ( ans, mul ( 2, mul ( n - 1, fct ) ) ) );
return 0;
}