AGC021E Ball Eat Chameleons
题目大意: 有 (n) 条变色龙,一开始都为蓝色,现在要求合法的长度为 (k) 的 RB 序列数量。对于一个序列,会按照序列顺序丢对应颜色的球,每一个球有随机一只变色龙吃下。如果这条变色龙吃的当前颜色球数量大于另一种颜色球数量,就会变色。求有多少序列,使得最终所有变色龙有可能变为红色。
数据范围:(1leq n,kleq 5 imes 10^5) 。
解题思路:考虑给定一个序列,怎样做才尽可能让所有变色龙变成红色。注意到红球和蓝球可以相互抵消,而让一个变色龙变色需要额外一个球的代价。那么对于一个红色的变色龙,如果没有吃过蓝球,可以免费吃一个,如果没有这样的变色龙,所有的蓝球都会丢到同一个变色龙上,因为丢在多个会付出多份代价。所以可以得到一个正确的策略,首先空出1号变色龙,对于一个红球,如果还有未丢过球的变色龙,就让它变色,否则丢给1号变色龙。而对于一个蓝球,如果有可以免费抵消的红色变色龙,则丢给它,否则丢给1号变色龙。
设红球有 (R) 个,蓝球有 (B) 个,这样子的策略1号变色龙吃到的红球数为 (t=R-(n-1)) ,此时如果 (R < B) ,会发现一号变色龙吃到的蓝球数量最少为 (B-(n-1)) 无解。
考虑当红球放了 (n-1) 个之前,1号变色龙吃到蓝球当且仅当没有红球可以免费抵消了,也就是蓝球匹配前面的红球无法匹配,记蓝球权值为 (1),红球权值为 (-1),那么 1号变色龙在红球放了 (n-1) 个之前吃到的蓝球数量为权值前缀和的 (max) 。也就是说在这之前权值前缀和的 (max) 必须 (<t) 。
考虑红球放了 (n-1) 个之后的情况,如果 (R=B) ,设之前吃到的蓝球数为 (s) ,显然有 (s < t =B -(n-1)) ,也就是剩下的蓝球数大于红球数,那么1号变色龙吃到的蓝球数一定为 (t) 个。由于要保证变色,最后一个球只能是蓝球,所以方案数就是任意时刻前缀 (max < t) 且最后一个球是蓝球的方案数,即 ({k-1choose R}-{k-1choose R+t}) 。
如果 (R>B) ,如果剩下的蓝球数能完全被前 ((n-1)) 个红球抵消,那么显然是一个合法的序列,而这个序列是一定满足 (max < t) 的,否则说明这 ((n-1)) 个红球中的每一个都抵消了一个蓝球,而 (R-(n-1)>B-(n-1)) 一定是一个合法的序列,此时倒推可知满足 (max < t) ,所以在这种情况下 (max < t) 是合法的充要条件,方案数就是 ({kchoose R}-{kchoose R+t})。
方案数推导,方案数可以看做从 ((0,0)) 走到 ((N,M)) 且不能碰到直线 (y-x=L) 的方案数,那么对于一个不合法的方案,找到第一次碰到直线 (y-x=L) 的位置,并将之前的部分沿直线翻折,会等价于一个从 ((-L,L)) 走到 ((N,M)) 的方案数,所以就是两个组合数相减 ({N+Mchoose N}-{N+Mchoose N+L}) 。
code
/*program by mangoyang*/
#pragma GCC optimize("Ofast", "inline")
#include<bits/stdc++.h>
#define inf (0x3f3f3f3f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 1000005, mod = 998244353;
int js[N], inv[N], n, k;
inline int Pow(int a, int b){
int ans = 1;
for(; b; b >>= 1, a = 1ll * a * a % mod)
if(b & 1) ans = 1ll * ans * a % mod;
return ans;
}
inline int C(int x, int y){
if(x < 0 || y < 0 || x < y) return 0;
return 1ll * js[x] * inv[y] % mod * inv[x-y] % mod;
}
inline void up(int &x, int y){
x = x + y >= mod ? x + y - mod : x + y;
}
int main(){
js[0] = 1, inv[0] = 1;
for(int i = 1; i < N; i++){
js[i] = 1ll * js[i-1] * i % mod;
inv[i] = Pow(js[i], mod - 2);
}
read(n), read(k);
int ans = 0;
for(int i = n; i <= k; i++)
if(i > k - i){
int t = i - n + 1;
up(ans, C(k, i));
up(ans, mod - C(k, i + t));
}
else if(i == k - i){
int t = i - n + 1;
up(ans, C(k - 1, i));
up(ans, mod - C(k - 1 , i + t));
}
cout << ans << endl;
return 0;
}