最近在外面集训,生成函数这东西被提了好多次,于是我就觉得这东西应该挺重要的,最好学一下。
但是现在我只会写一些入门题。
啊对了,这有几篇不错的博客:
1.生成函数(母函数)——目前最全的讲解
2.组合数学之三 —— 生成函数
3.兔哥的趣谈生成函数 =v=
那么就先看第一道题吧。
HDU1028
像第一篇博客一样,举个例子:比如数字1,(x ^3)表示用三个1有1种方案;或者数字3,(x ^ 6)表示用2个3有一种方案。
也就是说,前面的系数表示方案,后面的指数表示凑成的这个值。
那么这个的生成函数就是(f(x) = (1 + x + x ^ 2 + x ^ 3 + ldots + x ^ n) + (1 + x ^ 2 + x ^ 4 + ldots + x ^ {2n}) + ldots + (1 + x ^ n + x ^ {2n} + ldots + x ^ {n * n}))
把这个算出来后,指数为(n)的那一项的系数就是答案。
至于怎么算,就是暴力的多项式乘以多项式,单次(O(n ^ 2)),总复杂度(O(n ^ 3))。
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 125;
inline ll read()
{
ll ans = 0;
char ch = getchar(), last = ' ';
while(!isdigit(ch)) last = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(last == '-') ans = -ans;
return ans;
}
inline void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
int n;
ll f[2][maxn];
int main()
{
while(scanf("%d", &n) != EOF)
{
Mem(f, 0);
fill(f[0], f[0] + n + 1, 1);
for(int i = 2; i <= n; ++i)
{
for(int j = 0; j <= n; ++j)
for(int k = 0; j + k <= n; k += i)
f[1][j + k] += f[0][j];
for(int j = 0; j <= n; ++j) f[0][j] = f[1][j], f[1][j] = 0;
}
write(f[0][n]), enter;
}
return 0;
}
[HDU1085](http://acm.hdu.edu.cn/showproblem.php?pid=1085) 这道题是不是几乎和上一题一样? 嗯是的,只不过这一次规定了数目。 那也没影响啊,$f(x) = (1 +x +x ^ 2 + ldots + x ^ {num_1}) + (1 + x ^ 2 + x ^ 4 +ldots + x ^ {2 * num_2}) + (1 + x ^ 5 + x ^ {10} + ldots + x ^ {5 * num_5})$ 求出来后,答案就是最小的系数为0的项。 maxn大小要开到8000多。 注意这道题用bool类型即可,用int乘的话还会爆,导致像我一样TLE什么的。 ```c++ #include
int a[3] = {1, 2, 5}, num[3];
bool f[2][maxn];
int main()
{
while(scanf("%d%d%d", &num[0], &num[1], &num[2]) && num[0] + num[1] + num[2] > 0)
{
Mem(f, 0);
fill(f[0], f[0] + num[0] + 1, 1);
for(int i = 1; i <= 2; ++i)
{
for(int j = 0; j < maxn; ++j)
for(int k = 0; k <= num[i]; ++k)
if(f[0][j]) f[1][j + k * a[i]] = 1;
for(int j = 0; j < maxn; ++j) f[0][j] = f[1][j], f[1][j] = 0;
}
for(int i = 0; i < maxn; ++i) if(!f[0][i]) {write(i), enter; break;}
}
return 0;
}
</br>
[HDU1171](https://vjudge.net/problem/HDU-1171)
题意就是有n种东西,每一个的权值为v,有m个。将这些东西分成两拨,让权值和尽量接近。
</br>
生成函数很容易就能写出来,求完后从$frac{sum}{2}$开始向前查,第一个出现$i$的就是答案。
要注意到每一循环sum是累加的,否则会超时……
```c++
#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("")
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxN = 3e5 + 5;
inline ll read()
{
ll ans = 0;
char ch = getchar(), last = ' ';
while(!isdigit(ch)) last = ch, ch = getchar();
while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
if(last == '-') ans = -ans;
return ans;
}
inline void write(ll x)
{
if(x < 0) x = -x, putchar('-');
if(x >= 10) write(x / 10);
putchar(x % 10 + '0');
}
int n, sum = 0;
bool f[2][maxN];
int main()
{
while(scanf("%d", &n) && n > 0)
{
Mem(f, 0);
int v = read(), num = read();
sum = v * num;
for(int i = 0; i <= num; ++i) f[0][v * i] = 1;
for(int i = 1; i < n; ++i)
{
int v = read(), num = read();
for(int j = 0; j <= sum; ++j)
for(int k = 0; k <= num; ++k) f[1][j + k * v] += f[0][j];
sum += v * num;
for(int j = 0; j <= sum; ++j) f[0][j] = f[1][j], f[1][j] = 0;
}
for(int i = (sum >> 1); i >= 0; --i) if(f[0][i])
{
write(sum - i), space, write(i), enter;
break;
}
}
return 0;
}