【题目大意】
一个合法的引号序列是空串;如果引号序列合法,那么在两边加上同一个引号也合法;或是把两个合法的引号序列拼起来也是合法的。
求长度为$n$,字符集大小为$k$的合法引号序列的个数。多组数据。
$1 leq T leq 10^5, 1 leq n leq 10^7, 1leq k leq 10^9$
【题解】
显然引号序列可以看做括号序列,于是我们有了一个$O(n^2)$的dp了。
设$f_{i,j}$表示到第$i$个位置,前面有$j$个左引号没有匹配,的方案数
每次,要么有1种方案匹配前面的某一个引号,要么有$(k-1)$种方案开启一个新的左引号。
特别地,当$j=0$的时候,只能开启新的左引号,有$k$种方案。
就是当$jgeq 1$时:
$f_{i+1,j+1} = f_{i+1,j+1} + (k-1)f_{i,j}$
$f_{i+1,j-1} = f_{i+1,j-1} + f_{i,j}$
特别地,当$j=1$时:
$f_{i+1,j+1} = f_{i+1,j+1} + kf_{i,j}$
于是这是一个优秀的$O(n^2)$做法。
考虑如何优化,这里我们不讨论关于生成函数、暴力解方程等方法。
生成函数 大力化简 详见 https://chrt.github.io/2017/07/04/oeis-a183135/
考虑一种优秀的做法:
转化模型:等价于,我要在数轴上从0开始走$n$步,每次可以向正方向走、向负方向走,不能走到负半轴。当不在原点的时候,向正方向有$(k-1)$种方法,向负方向有1种方法;在原点的时候,向正方向有$k$种方法。最后回到原点的方案数。
由于$k = (k-1) + 1$,我们就可以把那种往上的方案对应成往下的,有如下转化:
我们定义一线表示实际的括号序列;二线表示每个一线对应的另外一种括号序列,见下。
比如一线的括号序列是左括号+1,右括号-1的折线;二线的括号序列就是右括号-1,左括号在0的时候-1,其他+1的折线。
更通俗的说,一线的括号序列是正常的一个括号序列;二线的括号序列是把一线括号序列中,每次从0开始,选择$k$种方法中的一种,走到1的这个左括号,人为看做右括号(因为它并没有贡献$(k-1)$的方案)。
那么二线对应着一种一线的括号序列。二线的左括号个数$i$,也就是实际需要乘$(k-1)$的个数(由于其他的左括号,是因为碰到了0,在$k$种方案中有1种方案向上,我们选择了那种方案导致)。
考虑有多少种二线有$i$个左括号的括号序列,对应到一线中是合法的。
二线有$i$个左括号,那么二线的终止位置是$-2(n-i)$;如果二线走到了$-2(n-i)-1$,相当于我一线从0用$k$种方法的1种方法,走到了1,这个一线实际上对应的二线方案应该只有$i-1$个左括号(因为这个左括号是没有用的,不应该被乘$(k-1)$)。
好需要证明每个走到$-2(n-i)-1$的二线有$i$个括号的方案和二线有$i-1$个括号的方案是一一对应的。这个其实显然,我走到了$-2(n-i)-1$的点后,把后面的括号翻转,就对应于一个有$i-1$的括号序列了。类似于卡特兰数的证明。
所以总的方案就是${2nchoose i} - {2nchoose i-1}$。
我们特殊定义${x choose -1} = 0$。
然后答案就是$sum_{i=0}^{n} ({2nchoose i} - {2nchoose i-1})(k-1)^i$
首先减法可以分开处理,我们只要处理$sum_{i=0}^{n}{2nchoose i}(k-1)^i$的线性递推问题即可。
这个我们可以用广义杨辉三角形来解决线性递推问题。
考虑杨辉三角形的构造
相当于把上一行复制一遍,(乘上对应系数1),移到下一行,右移一位,两两相加。
e.g
那么这个给我们创造了一个非常好的线性递推的思路,我们可以将上一行乘以$(k-1)$,移到下一行,右移一位,两两相加。
正确性非常明显,上一行第$i$个数,乘了$(k-1)$,相当于第$i+1$列,也就是下一行的第$i+1$个数。
然后我们发现我们求得是每两行的广义杨辉三角的前一半的和。
我们从偶数行的一半乘以$(k-1)+1$(包括自己,因为上面是一行+另一行乘$(k-1)$),减去边界,推到奇数行;
再从奇数行的一半乘以$(k-1)+1$,减去边界,推到偶数行,注意边界问题。
最好是自己模拟下k=3的情况,然后就能知道边界是什么了。
注意0的情况答案是1.
可能卡卡常就过了?
# include <stdio.h> # include <string.h> # include <iostream> # include <algorithm> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef long double ld; const int N = 1e7 + 5, M = 2e7 + 5, F = 1e7, FM = 2e7; const int mod = 1e9 + 7; int n, K, A; int f[N], g[N], p[N], c[N]; int fac[M], inv[M]; inline int pwr(int a, int b) { int ret = 1; while(b) { if(b&1) ret = 1ll * ret * a % mod; a = 1ll * a * a % mod; b >>= 1; } return ret; } inline int C(int n, int k) { if(n < k) return 0; return 1ll * fac[n] * inv[k] % mod * inv[n-k] % mod; } int main() { // freopen("quote.in", "r", stdin); // freopen("quote.out", "w", stdout); int T; cin >> T >> K; A = K--; if(!K) { while(T--) scanf("%d", &n), puts("1"); return 0; } fac[0] = inv[0] = p[0] = 1; for (int i=1; i<=FM; ++i) fac[i] = 1ll * fac[i-1] * i % mod; inv[FM] = pwr(fac[FM], mod-2); for (int i=FM-1; i; --i) inv[i] = 1ll * inv[i+1] * (i+1) % mod; for (int i=1; i<=F; ++i) p[i] = 1ll * p[i-1] * K % mod; f[1] = K + K + 1; if(f[1] >= mod) f[1] -= mod; for (int i=2, t; i<=F; ++i) { t = 1ll * A * f[i-1] % mod; t = t + 1ll * C(i*2-2, i) * p[i] % mod; if(t >= mod) t -= mod; t = 1ll * A * t % mod; t = t - 1ll * K * C(i*2-1, i) % mod * p[i] % mod; if(t < 0) t += mod; f[i] = t; } for (int i=1, t; i<=F; ++i) { t = f[i] - 1ll * C(i*2, i) * p[i] % mod; g[i] = 1ll * t * K % mod; if(g[i] < 0) g[i] += mod; f[i] -= g[i]; if(f[i] < 0) f[i] += mod; } f[0] = 1; while(T--) { scanf("%d", &n); printf("%d ", f[n]); } return 0; }
upd: 本题卡常技巧
减少取模次数,数组大可以开static(迷)
upd2: 卡了波常,然后过了
# include <stdio.h> # include <string.h> # include <iostream> # include <algorithm> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef long double ld; const int N = 1e7 + 5, M = 2e7 + 5, F = 1e7, FM = 2e7; const int mod = 1e9 + 7; inline int getint() { int x = 0; char ch = getchar(); while(!isdigit(ch)) ch = getchar(); while(isdigit(ch)) x = (x<<3) + (x<<1) + ch - '0', ch = getchar(); return x; } int n, K, A; int fac[M], inv[M], p[N], c[N]; inline int pwr(int a, int b) { int ret = 1; while(b) { if(b&1) ret = 1ll * ret * a % mod; a = 1ll * a * a % mod; b >>= 1; } return ret; } inline int C(int n, int k) { if(n < k) return 0; return 1ll * fac[n] * inv[k] % mod * inv[n-k] % mod; } int main() { // freopen("quote.in", "r", stdin); // freopen("quote.out", "w", stdout); static int f[N], g[N]; int T; T = getint(); K = getint(); A = K--; if(!K) { while(T--) puts("1"); return 0; } fac[0] = inv[0] = p[0] = 1; for (int i=1; i<=FM; ++i) fac[i] = 1ll * fac[i-1] * i % mod; inv[FM] = pwr(fac[FM], mod-2); for (int i=FM-1; i; --i) inv[i] = 1ll * inv[i+1] * (i+1) % mod; for (int i=1; i<=F; ++i) p[i] = 1ll * p[i-1] * K % mod; f[1] = K + K + 1; if(f[1] >= mod) f[1] -= mod; for (int i=2, t; i<=F; ++i) { t = (1ll * A * f[i-1] + 1ll * C(i*2-2, i) * p[i]) % mod; t = (1ll * A * t - 1ll * K * C(i*2-1, i) % mod * p[i]) % mod; if(t < 0) t += mod; f[i] = t; } for (int i=1, t; i<=F; ++i) { t = f[i] - 1ll * C(i*2, i) * p[i] % mod; g[i] = 1ll * t * K % mod; if(g[i] < 0) g[i] += mod; f[i] -= g[i]; if(f[i] < 0) f[i] += mod; } f[0] = 1; while(T--) { n = getint(); printf("%d ", f[n]); } return 0; }