传送
题面:有(n(n leqslant 20))种花,每种花的数量(f_i(f_i leqslant 10^{12}))已知,现在要取(s(s leqslant 10^{14}))朵,求方案数。
继续练习生成函数。
这道题的生成函数非常好推,就是$$prod_{i=1}^n frac{1-x^{f_i+1}}{1-x}$$
然后用广义二项式定理进行化简:
[egin{align*}
prod_{i=1}^n frac{1-x^{f_i+1}}{1-x} &= frac{prod_{i=1}^n (1-x^{f_i+1})}{(1-x)^n}\
&= prod_{i=1}^n (1-x^{f_i+1})sum_{k=0}^infty C_{n+k-1}^{n-1}x^k\
end{align*}]
因为(n)只有(20),所以前面的乘积可以暴力乘,但因为次数很大,所以不能用次数作为数组下标,开一个结构体模拟一下即可。
这样乘完后是一个有限项的多项式,我们枚举这个多项式的所有项(a_ix^i),结合后面的展开式,答案就是(sumlimits_{i} a_i * C_{n + s - i - 1} ^{n-1}).
需要注意的是,组合数(C_{n}^m)中的(n)很大,(m)很小,所以应该暴力算出来(frac{n!}{m!}),再乘以(frac1{m!}).
这道题的主要收获是,算法竞赛中的数学题,不一定要推出(O(1))的式子,可以结合数据范围在适当位置进行暴力求解。
#include<bits/stdc++.h>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define In inline
typedef long long ll;
typedef double db;
const int maxn = 25;
const ll mod = 1e9 + 7;
In ll read()
{
ll ans = 0;
char ch = getchar(), las = ' ';
while(!isdigit(ch)) las = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(las == '-') ans = -ans;
return ans;
}
In void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
ll f[maxn], inv[maxn];
In ll ADD(ll a, ll b) {return a + b < mod ? a + b : a + b - mod;}
In ll quickpow(ll a, ll b)
{
ll ret = 1;
for(; b; b >>= 1, a = a * a % mod)
if(b & 1) ret = ret * a % mod;
return ret;
}
In ll C(ll n, int m)
{
ll ret = 1;
for(ll i = n - m + 1; i <= n; ++i) ret = ret * (i % mod) % mod;
return ret * inv[m] % mod;
}
int n;
ll s, a[maxn];
#define pr pair<ll, ll>
#define mp make_pair
#define F first
#define S second
vector<pr> v[2];
In ll solve()
{
int o = 0;
v[o].push_back(mp(0, 1));
for(int i = 1; i <= n; ++i, o ^= 1)
{
v[o ^ 1] = v[o];
for(int j = 0, k = 0; j < (int)v[o].size(); ++j)
{
ll tp = v[o][j].F + a[i] + 1;
while(k < (int)v[o ^ 1].size() && v[o ^ 1][k].F < tp) ++k;
if(k < (int)v[o ^ 1].size() && v[o ^ 1][k].F == tp) v[o ^ 1][k].S = ADD(v[o ^ 1][k].S, mod - v[o][j].S);
else v[o ^ 1].push_back(mp(tp, mod - v[o][j].S));
}
v[o] = v[o ^ 1];
}
ll ans = 0;
for(auto x : v[o])
if(x.F <= s) ans = ADD(ans, x.S * C(n - 1 + s - x.F, n - 1) % mod);
return ans;
}
int main()
{
f[0] = inv[0] = 1;
for(int i = 1; i < maxn; ++i) f[i] = f[i - 1] * i % mod;
inv[maxn - 1] = quickpow(f[maxn - 1], mod - 2);
for(int i = maxn - 2; i; --i) inv[i] = inv[i + 1] * (i + 1) % mod;
n = read(), s = read();
for(int i = 1; i <= n; ++i) a[i] = read();
write(solve()), enter;
return 0;
}