最近和队友打了去年省赛的题。。有点自闭。。。1h时开始看这题,结束了还是没搞出来。。。
dp的部分能搞出来,但是最后实在想不出答案该怎么算出来。。。看了大神的博客恍然大悟
我们先考虑把一个长为\(i\)的链切成\(j\)段得到的所有情况的结果之和,用\(dp[i][j]\)存下来。
很容易就能想到一个\(O(n^3)\)的状态转移方程:
一眼看上去感觉能打前缀和优化。。。的确能
\(dp[i+1][j] = \sum_{k=1}^{i-j+2}{k*dp[i-k+1][j-1]}\)
\(=\sum_{k=0}^{i-j+1}{(k+1)*dp[i-k][j-1]}\)
\(=dp[i][j-1]+\sum_{k=1}^{i-j+1}{(k+1)*dp[i-k][j-1]}\)
\(=dp[i][j-1]+\sum_{k=1}^{i-j+1}{k*dp[i-k][j-1]}+\sum_{k=1}^{i-j+1}dp[i-k][j-1]\)
\(=dp[i][j-1]+dp[i][j]+\sum_{k=1}^{i-j+1}dp[i-k][j-1]\)
\(=dp[i][j]+\sum_{k=0}^{i-j+1}dp[i-k][j-1]\)
\(=dp[i][j]+\sum_{k=j-1}^{i}dp[k][j-1]\)
用\(sum[i][j]代替\sum_{k=j-1}^{i}dp[k][j-1]\),就有了\(O(n^2)\)的转移方程
接下来我们枚举一个\(i\),考虑把一个长为\(m\)的白色序列切成\(i\)段的结果之和,很显然就是\(dp[m][i]\).切\(i\)段要用\(i\)刀(本来是\(i-1\)刀,因为把一个环切成链需要1刀,这刀在这条白链的第一个点前面或最后一个点后面,所以一共\(i\)刀),
回到原题每刀就代表着一个黑色序列;反过来就相当于把一个长为\(n\)的黑色序列切了\(i\)刀,所以再乘一个\(dp[n][i]\).
-----------------------------------------------------------------------到这边我都会做,下面我就实在是想不到了。。。。------------------------------------------------------------
现在我们有了一条链的答案,如何把它转换成一个环呢?
比如这条链是OXOOXOOOX(O是白,X是黑)
最直观的想法就是有\((n+m)\)种填编号的方法,所以答案再乘个\((n+m)\),但是这样显然是有重复的。。。等价的情况没有考虑。。。
两个方面:一是如果是旋转等价的OOXOOXOOX,会有重复;二是事实上在dp的过程中,OXOOXOOOX和OOXOOOXOX和OOOXOXOOX都算过了,我们如果直接乘个\((n+m)\)也是会重复的
为简单起见,我们现在忽略黑色段的长度,只看这\(i\)个白色段的长度序列,如OXOOXOOOXOXOOXOOOX我们记作1 2 3 1 2 3
回顾刚刚提到的这两种可能出现重复的情况,我们会发现对这个例子,第一种旋转等价出现了两次的重复(这两次是在乘\((n+m)\)时被重复计算的),第二种等价出现了三次的重复(这三次是在dp里被重复计算的),乘起来刚好是6次。
进一步地,我们可以发现对任何计入结果的白色段数为\(i\)的序列,它都被重复计算了\(i\)次,因此之前的结果除以\(i\),就是最后的答案。。。。。。
考虑回黑色段的长度,结论类似。
(以上都是本人脑补,并未经过严谨的数学证明,如有错误请指出)
另外直接开两个5000*5000的数组会MLE,滚动一下就好了
第一次写博客,请大家多多指教
#include <bits/stdc++.h>
#define pii pair<int,int>
#define pid pair<int,double>
#define LL long long
#define MAXN 100000
using namespace std;
const LL mod = 1000000007;
LL dp[5008][5008];
//LL dp2[5008][5008];
LL sum[2][5008];
LL inv[5008];
LL qp(LL a, LL b)
{
LL ret = 1;
while (b)
{
if (b & 1)
ret = ret * a % mod;
a = a * a % mod;
b >>= 1;
}
return ret;
}
void gao()
{
dp[0][0] = 1;
sum[0][0] = 1;
int cc = 0;
for (int i = 1; i <= 5000; ++i)
{
inv[i] = qp(i, mod - 2);
dp[i][0] = sum[cc^1][0] = 1;
for (int j = 1; j <= i; ++j)
{
/*dp[i][j] = dp[i - 1][j];
dp[i][j] %= mod;
for (int k = j - 1; k <= i - 1; ++k)
{
dp[i][j] += dp[k][j - 1];
dp[i][j] %= mod;
}*/
dp[i][j] = dp[i - 1][j] + sum[cc][j - 1];
dp[i][j] %= mod;
sum[cc^1][j] = sum[cc][j] + dp[i][j];
sum[cc^1][j] %= mod;
}
cc ^= 1;
}
/*dp2[0][0] = 1;
for (int i = 1; i <= 50; ++i)
{
for (int j = 1; j <= i; ++j)
{
for (int k = 1; k <= i - j + 1; ++k)
{
dp2[i][j] += k * dp2[i - k][j - 1];
dp2[i][j] %= mod;
}
}
}*/
}
LL n, m;
int main()
{
gao();
while (~scanf("%lld %lld", &n, &m))
{
LL ans = 0;
for (int i = 1; i <= n && i <= m; ++i)
{
ans += dp[n][i] * dp[m][i] % mod * inv[i] % mod;
ans %= mod;
}
ans *= (n + m);
ans %= mod;
printf("%lld\n", ans);
}
return 0;
}
//265ms 118340kB