Description
给一个长度为(n)的(a)数组与一个长度为(m)的(b)数组,求把(a)数组划分为(m)段使得对每个(i)都有第(i)段最小值为(b_i)的方案数((mod) (998244353))
Solution
设(f[i])表示(a)数组中划分到第(i)位(只考虑(a_i)与(b)中某元素相等的(i)),(a_i)与(b_k)相等,(a_i)为第(k)段最小值。
转移是(f[i]+=f[j]*calc(j,i)),(a_j=b_{k-1}),(calc(j,i))计算的是将([j+1,i-1])中的元素分为两部分,满足前一部分属于第(k-1)段,后一部分属于第(k)段的可行方案数(仍保证第(k-1)段最小值为(a_j),第(k)段最小值为(a_i))
发现方程中的(j)位置只需要取最靠后的满足(a_j=b_{k-1})的即可,因为在把([j+1,i-1])分为两半时,若中间有位置(m)满足(a_m=a_j),由于(a_m=a_j<a_i),位置(m)必被归类于第(k-1)段,那完全可以直接用位置(m)进行转移
(好吧我承认思路有那么一点点奇怪(讲得似乎也有那么一点点奇怪),对比正解存在一定冗余。正解好像是(O(n)),我做法中离散化、预处理(ST)表都为(nlogn),计算(calc)用的是倍增或二分,计算一次是(logn)的复杂度,总复杂度(O(nlogn))。)
Code
#include <bits/stdc++.h>
#define Mod 998244353
using namespace std;
typedef long long ll;
inline int read() {
int out = 0;
bool flag = false;
register char cc = getchar();
while (cc < '0' || cc > '9') {
if (cc == '-') flag = true;
cc = getchar();
}
while (cc >= '0' && cc <= '9') {
out = (out << 3) + (out << 1) + (cc ^ 48);
cc = getchar();
}
return flag ? -out : out;
}
inline void write(int x) {
if (x < 0) putchar('-'), x = -x;
if (x == 0) putchar('0');
else {
int num = 0;
char cc[20];
while (x) cc[++num] = x % 10 + 48, x /= 10;
while (num) putchar(cc[num--]);
}
putchar(' ');
}
int n, m, a[200010], b[200010], c[400010], pre[200010], lst[200010], tot, Log[200010], Min[20][200010], f[200010];
inline int MIN(const int &l, const int &r) {
int t = Log[r - l + 1];
if (Min[t][l] < Min[t][r - (1 << t) + 1]) return Min[t][l];
else return Min[t][r - (1 << t) + 1];
}
inline int calc(const int &l, const int &r) {
int x = l, y = r;
for (int i = 18; i >= 0; i--)
if (x + (1 << i) < r && MIN(l, x + (1 << i)) >= a[l]) x += 1 << i;
for (int i = 18; i >= 0; i--)
if (y - (1 << i) > l && MIN(y - (1 << i), r) >= a[r]) y -= 1 << i;
if (x < y - 1) return 0;
return x - y + 2;
}
int main() {
n = read(), m = read();
for (int i = 1; i <= n; i++) c[++tot] = a[i] = read();
for (int i = 1; i <= m; i++) c[++tot] = b[i] = read();
sort(c + 1, c + tot + 1);
tot = unique(c + 1, c + tot + 1) - c - 1;
for (int i = 1; i <= n; i++) a[i] = lower_bound(c + 1, c + tot + 1, a[i]) - c;
for (int i = 1; i <= m; i++) b[i] = lower_bound(c + 1, c + tot + 1, b[i]) - c;
for (int i = 2; i <= n; i++) Log[i] = Log[i >> 1] + 1;
for (int i = 1; i <= n; i++) Min[0][i] = a[i];
for (int k = 1; (1 << k) <= n; k++) {
for (int i = 1; i + (1 << k) - 1 <= n; i++) {
if (Min[k - 1][i] < Min[k - 1][i + (1 << (k - 1))]) Min[k][i] = Min[k - 1][i];
else Min[k][i] = Min[k - 1][i + (1 << (k - 1))];
}
}
for (int i = 1; i <= m; i++) lst[b[i]] = b[i - 1];
int o = INT_MAX;
for (int i = 1; i <= n; i++) {
o = min(o, a[i]);
if (a[i] == b[1] && o == a[i]) f[i] = 1;
}
for (int i = 1; i <= n; i++) {
//cout << lst[a[i]] << ' ' << pre[lst[a[i]]] << endl;
if (!f[i]) f[i] = 1ll * f[pre[lst[a[i]]]] * calc(pre[lst[a[i]]], i) % Mod;
//cout << f[i] << endl;
pre[a[i]] = i;
}
int ans = 0;
for (int i = 1; i <= n; i++) if (a[i] == b[m] && MIN(i, n) >= a[i]) {
ans = f[i];
//if (ans >= Mod) ans -= Mod;
}
cout << ans << endl;
return 0;
}