题面
各位请看这个式子
[displaystyle
leftvert s
ightvert * val_s
]
设(val_s)为(w_1 + w_2 + cdots +w_s)
所以上式就可以表示为
[displaystyle
egin{aligned}
w_1+w_2&+cdots +w_s\
w_1+w_2&+cdots +w_s\
w_1+w_2&+cdots +w_s\
&cdots \
w_1+w_2&+cdots +w_s\
end{aligned}
]
可以发现, (w_1)总共被贡献了s次, 分别是由(w_1, w_2, cdots, w_s)贡献的, 所以我们可以这样理解每个数对最终答案的贡献
一. 由自己所贡献, 那么这样的方案数为
[displaystyle
egin{Bmatrix}n\kend{Bmatrix}
]
即自己在任意一种将(n)个数分为(k)个集合的方案中都对答案有贡献
二. 由其他数所贡献, 这样的方案数为
[displaystyle
(n - 1)egin{Bmatrix}n - 1\kend{Bmatrix}
]
即这个数跟任意一个其他的数在同一个集合中都会有贡献, 选一个其他的数共有((n - 1))种选择, 然后对于某个其他的数, 其分在(k)个集合中任一个都会有对该数有一次贡献, 然后分在(k)个集合中的任一个集合的方案数是(egin{Bmatrix}n-1\kend{Bmatrix})
所以就是上面那个式子
故最后的答案是
[displaystyle
ans = (egin{Bmatrix}n\kend{Bmatrix}+(n-1)egin{Bmatrix}n-1\kend{Bmatrix})sum_{i=1}^{n}w_i
]
Code
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#define itn int
#define reaD read
#define mod 1000000007
#define N 200005
using namespace std;
itn n, m, jc[N], inv[N], sum, ans;
inline int read()
{
int x = 0, w = 1; char c = getchar();
while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
return x * w;
}
int fpow(int x, int y)
{
int res = 1;
while(y)
{
if(y & 1) res = 1ll * res * x % mod;
x = 1ll * x * x % mod;
y >>= 1;
}
return res;
}
int C(int n, int m)
{
int res = jc[n];
res = 1ll * res * inv[m] % mod;
res = 1ll * res * inv[n - m] % mod;
return res;
}
int Stirl(int n, int m)
{
int res = 0;
for(int i = 0; i <= m; i++)
{
int num = i & 1 ? mod - 1 : 1;
num = 1ll * num * C(m, i) % mod;
num = 1ll * num * fpow(m - i, n) % mod;
res = 1ll * (res + num) % mod;
}
res = 1ll * res * inv[m] % mod;
return res;
}
int main()
{
n = read(); m = read();
for(int i = 1; i <= n; i++) sum = 1ll * (sum + read()) % mod;
for(int i = (inv[1] = jc[1] = inv[0] = jc[0] = 1) + 1; i <= n; i++)
{
jc[i] = 1ll * jc[i - 1] * i % mod;
inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
}
for(int i = 1; i <= n; i++) inv[i] = 1ll * inv[i] * inv[i - 1] % mod;
ans = 1ll * (ans + Stirl(n, m)) % mod; ans = 1ll * (ans + 1ll * (n - 1) * Stirl(n - 1, m) % mod) % mod;
ans = 1ll * ans * sum % mod;
printf("%d
", ans);
return 0;
}