CF1327F AND Segments
题意翻译
题目描述
你有三个整数 n, k, mn,k,m 以及 mm 个限制 (l_1, r_1, x_1), (l_2, r_2, x_2), ldots, (l_m, r_m, x_m)(l1,r1,x1),(l2,r2,x2),…,(l**m,r**m,x**m)。
计算满足下列条件的,长度为 nn 的序列 aa 的个数:
- 对于每个 1 le i le n1≤i≤n,0 le a_i lt 2 ^ k0≤a**i<2k。
- 对于每个 1 le i le m1≤i≤m,数字的按位与 a[l_i] ext{ and } a_[l_i + 1] ext{ and } ldots ext{ and } a[r_i] = x_ia[l**i] and a[l**i+1] and … and a[r**i]=x**i。
两个序列 a, ba,b 被认为是不同的,当且仅当存在一个位置 ii 满足 a_i eq b_ia**i�=b**i。
由于答案可能过大,请输出其对 998 244 353998 244 353 取模的结果。
输入格式
第一行输入三个整数 n, k, m ~(1 le n le 5 cdot 10 ^ 5; 1 le k le 30; 0 le m le 5 cdot 10 ^ 5)~n,k,m (1≤n≤5⋅105;1≤k≤30;0≤m≤5⋅105) ,分别表示数组 aa 的长度,aa 中元素的值域,以及限制的个数。
接下来 mm 行,每行描述一个限制 l_i, r_i, x_i ~ (1 le l_i le r_i le n; 0 le x_i lt 2 ^ k)l**i,r**i,x**i (1≤l**i≤r**i≤n;0≤x**i<2k),分别表示限制的线段区间以及按位与值。
输出格式
输出一行一个整数,表示满足条件的序列 aa 的个数,对 998 244 353998 244 353 取模的结果。
题解:
对于这种题目,一般都是从位运算按位独立这里入手。这题也不例外。
既然位运算相互独立,且本题只有与运算,那么很容易得出这样的一个性质:对于某个约束区间([l,r])的约束值(x),对(x)二进制拆分,(x)的某位如果为1,那么这个区间的数的对应位全部为1。反之如果为0,那么这个区间的数的对应位至少有一个为0.
推到此处狂喜,然后觉得可以用排列组合来做,或者是个容斥什么的。但是这是不现实的,因为区间的并实在太难处理了。
但是想到排列组合却能给我们一个启发。那就是计数原理的乘法原理。既然这些位都是互相独立的,那么每位的合法方案的方案数乘到一起就是答案。
每位的合法方案数......既然排列组合容斥啥的都没戏,那就只能考虑DP,毕竟对于计数类问题我们也并没有什么其他的方法。这个不行就试那个。
设(dp[i])表示处理到前i个数,且第(i)个数的对应位为0的方案数。不需要再开一维来存第几位,每次做的时候清空即可。
那么很显然,如果某个约束条件使得一个区间内的数字的对应位为1,那么这段区间的dp值应该都为0.
这个0可以用多种方法维护,我选择的是差分。
如果某个约束条件使得一个区间内的数字的对应位至少有一个为0,这种情况该怎么办呢?
分类讨论:首先如果某个区间的右端点比当前枚举到的点大,那么这个条件已经被满足,不用考虑。
如果某个区间([l,r])的右端点比当前枚举到的点(i)小,那么这个条件要被满足,一定会在这里填一个0,所以就可以转移:
为什么要到(i-1)而不是到(r)呢?因为我们统计的(dp[i])已经被叫做合法方案,所以在这个区间的都合法,也就都要统计进来。
那么我们敏锐地发现,这样统计会导致很多区间被重复算进去。那么我们为了不重复统计,就采用一个(lmx[i])数组,表示在(i)左边最大的(l)值。这样的话每次转移只会从最大的开始累加,而前面的已经在前面的转移被累加过了。这样就完成了这道题。
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=5e5+5;
const int mod=998244353;
int n,k,m;
int l[maxn],r[maxn],x[maxn];
int lmx[maxn],dp[maxn],cf[maxn];
int calc(int a)
{
memset(lmx,0,sizeof(lmx));
memset(cf,0,sizeof(cf));
memset(dp,0,sizeof(dp));
for(int i=1;i<=m;i++)
{
if(x[i]>>a&1)
cf[l[i]]++,cf[r[i]+1]--;
else
lmx[r[i]]=max(lmx[r[i]],l[i]);
}
for(int i=1;i<=n+1;i++)
{
lmx[i]=max(lmx[i],lmx[i-1]);
cf[i]+=cf[i-1];
}
int sum=1,j=0;
dp[0]=1;
for(int i=1;i<=n+1;i++)
{
if(cf[i])
{
dp[i]=0;
continue;
}
while(j<lmx[i-1])
{
sum=(sum-dp[j]+mod)%mod;
j++;
}
dp[i]=sum;
sum=(sum+dp[i])%mod;
}
return dp[n+1];
}
int main()
{
scanf("%d%d%d",&n,&k,&m);
for(int i=1;i<=m;i++)
scanf("%d%d%d",&l[i],&r[i],&x[i]);
int ans=1;
for(int i=0;i<k;i++)
ans=1ll*ans*calc(i)%mod;
printf("%d
",ans);
return 0;
}