传送门:https://codeforces.com/gym/102956
题目大意:
统计满足下列条件的数列的方案数:
- 非空
- 严格递增
- 任意连续三个元素的异或和不为 (0)
- 元素小于等于给定的 (n)
分析:
(f[i]) 表示以 (i) 为尾的方案数。
考虑状态转移:
- 如果 (f[i]) 只有一个元素,自然属于一种方案。
- 如果 (i) 的上一个数是 (j ~ (j in [1,i-1]))
- (i oplus j >= j) 时,直接转移即可 (f[i]+=f[j])
- 否则, (f[i]+=f[j]-f[j oplus i])
写出更为一般化的公式就是:
(f[i]=sum_{u=1}^{i-1} f[u] - sum_{j=1}^{i-1} f[i oplus j]),这里的 (j) 满足 ((j > i oplus j))
我们记 (highbit(x)) (
是我乱起的)为二进制中 (x) 除了最高位为 (1) 其余全部清 (0) 所对应的数,例如:(highbit(11001)=10000)
注意到 (j) 与 (i) 必须位数相等(充分必要条件)。
因此,我们可以得到更为精确的 (j) 的范围 ([highbit(i),i-1])
记 (k=ioplus j) ,我们只需找到一个快速统计 (k) 的办法就好了。
举两个例子(二进制):
(i=10100,故~j=10000-10011) ,对应的 (k) 为 (100-111) (不一定是依次对应的)。
(i=101010,故~j=1000000-101001) ,
对应的 (k) 为 (1000-1111),以及 (10-11) (不一定是依次对应的)。
从中我们发现了规律:
记 (t=i-highbit(i)) ,然后每一次取 (t) 的 (highbit) ,区间 ([highbit(t),highbit(2t-1)]) 都是满足的,我们依次统计就好了。
证明挺简单的,就留作习题吧。
我自己乱搞了一个求 (highbit(x)) 的算法:
就是求 (2^{lfloor log_2x floor}) 即可。
如果感觉不太好懂可以用下面的打表程序试试:
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+6;
int highbit[N];
void init(){
for(int i=2;i<N;i++){
highbit[i]=highbit[i/2]+1;
}
for(int i=2;i<N;i++)
highbit[i]=1LL<<highbit[i];
highbit[1]=1;
}
void get(int x){
vector<int> v;
while(x) v.push_back(x%2), x>>=1;
reverse(v.begin(),v.end());
for(auto i:v) cout<<i;
cout<<' ';
}
int main(){
init();
int i; cin>>i;
for(int j=highbit[i];j<=i-1;j++)
get(j^i), puts("");
return 0;
}
原题代码:
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+6;
const int mod=998244353;
long long f[N],s[N];
int highbit[N];
void init(){
for(int i=2;i<N;i++){
highbit[i]=highbit[i/2]+1;
}
for(int i=2;i<N;i++)
highbit[i]=1LL<<highbit[i];
highbit[1]=1;
}
int main(){
init();
int n; cin>>n;
cerr<<highbit[N-1];
for(int i=1;i<=n;i++){
auto &v=f[i];
v=1;
v=(v+s[i-1])%mod;
int t=i;
t-=highbit[t];
while(t){
int x=highbit[t];
v=(v-(s[(x<<1)-1]-s[x-1])+mod)%mod;
t-=x;
}
s[i]=(s[i-1]+v)%mod;
}
cout<<s[n]<<endl;
return 0;
}