LOJ 575. 不等关系
给定一个长度为 (n-1) 的字符串 (S) ,仅包含 < 和 > 两种字符。
你需要计算「使得 (p_i<p_{i+1}) 当且仅当 (S) 为 < 的排列 (P)」的数量。
可以发现,答案可能很大,因此你只要输出它对 (10^9+7) 取模的结果。
换而言之,对于 (S) 中的每个位置,若 (S_i=) <,那么 (p_i<p_{i+1}),否则 (p_i> p_{i+1})
- (nle 10^5)
Solution
从 (2^n) 的状压的角度考虑,对于一个位置,如果为 <
那么设为 0 否则设为 1 。
对 <
和 >
进行容斥,我们忽略 >
的限制,这样计算出来的答案相当于答案的高维前缀和。
假设已知了高维前缀和数组,想要得到 (S) 处的点值,只需要做高维前缀差分,此时其一个子集对答案的贡献为其中 ((-1)) 的 (1) 的数量差次幂。
考虑形式化的描述这个问题,我们现在的问题即:
将序列默认为全 <
,部分原本为 >
位置要么无限制,要么为 <
,一种情况对答案的贡献为方案数乘以 (-1) 的相应次幂。
考虑每一段连续的 <
,无限制相当于进行了分割。
不难发现我们需要考虑的仅是长度为 (n) 的序列拆成若干个单调序列的方案数。
使用 Dp 来统计这个答案,我们从前往后 dp,每次加入一个 >
时有两种决策:
- 无限制
- 为
<
这样我们需要记录结尾段的长度,复杂度为 (mathcal O(n^2))
考虑得到划分后是否可以直接计算出答案,不难发现我们的答案等价于将 (n) 个元素划分成若干个集合的方案数并配以 ((-1)) 作贡献,即:
更加正式的,将 <
的下标记录下来,不难发现等价于:
此递推式采用分治 NTT 优化即可。
不是很懂 (50) 和 (100) 的差别有啥,板子么?这个给分着实迷惑。不过我感觉自己已经 8 万年没有写过分治乘了,结果今天没怎么调就过了,还是十分感动的。
(Code:)
#include<bits/stdc++.h>
using namespace std ;
#define Next( i, x ) for( register int i = head[x]; i; i = e[i].next )
#define rep( i, s, t ) for( register int i = (s); i <= (t); ++ i )
#define drep( i, s, t ) for( register int i = (t); i >= (s); -- i )
#define re register
#define int long long
int gi() {
char cc = getchar() ; int cn = 0, flus = 1 ;
while( cc < '0' || cc > '9' ) { if( cc == '-' ) flus = - flus ; cc = getchar() ; }
while( cc >= '0' && cc <= '9' ) cn = cn * 10 + cc - '0', cc = getchar() ;
return cn * flus ;
}
const int P = 998244353 ;
const int Gi = 332748118 ;
const int G = 3 ;
const int N = 4e5 + 5 ;
int fpow(int x, int k) {
int ans = 1, base = x ;
while(k) {
if(k & 1) ans = 1ll * ans * base % P ;
base = 1ll * base * base % P, k >>= 1 ;
} return ans ;
}
int n, p[N], cnt, limit, Inv, A[N], B[N], L, R[N], Ans, fac[N], inv[N], f[N] ;
char s[N] ;
void init(int x) {
limit = 1, L = 0 ; while( limit < x ) limit <<= 1, ++ L ;
for(re int i = 0; i < limit; ++ i) R[i] = ( (R[i >> 1] >> 1) | ((i & 1) << (L - 1)) ) ;
Inv = fpow( limit, P - 2 ) ;
}
void NTT( int *a, int type ) {
for(re int i = 0; i < limit; ++ i) if( R[i] > i ) swap( a[i], a[R[i]] ) ;
for(re int k = 1; k < limit; k <<= 1) {
int d = fpow( (type == 1) ? G : Gi, (P - 1) / (k << 1) ) ;
for(re int i = 0; i < limit; i += (k << 1) )
for(re int j = i, g = 1; j < i + k; ++ j, g = g * d % P) {
int nx = a[j], ny = a[j + k] * g % P ;
a[j] = (nx + ny) % P, a[j + k] = (nx - ny + P) % P ;
}
} if( !type ) for(re int i = 0; i < limit; ++ i) a[i] = a[i] * Inv % P ;
}
int st[N], top ;
void CDQ(int l, int r) {
if( l == r ) { if(l == 1) f[l] = 1 ; return ; }
int mid = (l + r) >> 1 ; CDQ(l, mid) ; top = 0 ;
for(re int i = l; i <= mid; ++ i) st[++ top] = f[i] ;
for(re int i = 0; i <= top; ++ i) A[i] = st[i] ;
for(re int i = 0; i <= top * 2; ++ i) B[i] = inv[i] ;
init(top * 3 + 5), NTT( A, 1 ), NTT( B, 1 ) ;
for(re int i = 0; i < limit; ++ i) A[i] = A[i] * B[i] % P ;
NTT( A, 0 ) ;
for(re int i = mid + 1; i <= r; ++ i)
if( p[i] ) f[i] = ( f[i] - A[i - mid + top] + P ) % P ;
for(re int i = 0; i <= limit; ++ i) A[i] = B[i] = 0 ;
CDQ(mid + 1, r) ;
}
signed main()
{
fac[0] = inv[0] = 1 ;
scanf("%s", s + 1 ), n = strlen(s + 1) + 1 ;
rep( i, 2, n ) p[i] = ( s[i - 1] == '>' ) ? 1 : 0 ;
rep( i, 1, n ) cnt += p[i] ;
rep( i, 1, n ) fac[i] = fac[i - 1] * i % P, inv[i] = fpow( fac[i], P - 2 ) ;
p[n + 1] = 1, CDQ(1, n + 1) ; int Ans = 0 ;
if( cnt & 1 ) Ans = f[n + 1] ;
else Ans = P - f[n + 1] ;
Ans = Ans * fac[n] % P ;
cout << Ans % P << endl ;
return 0 ;
}