幂次序列 哈希+启发式合并
题目描述
分析
我们先不考虑精度问题
暴力的思想是对于每一个点(i),向前找是否存在一个点(j),使得(sum[i]-sum[j-1]=2^k)
我们考虑优化这个暴力
对于一段长度为(k)的区间,我们可以找到这个区间中的最大的数(a[i])
而区间的和的指数一定不会超过(2^{log_k+a[i]})
因此我们可以把这个区间分成两半,一半在最大的数的左边,另一半在最大的数的右边
我们用指针在长度较短的区间从左到右扫一遍,同时在长度较长的区间的哈希表中找和较小的部分拼起来的和恰好为(2)的幂的值的个数
因为前缀和太大无法直接存,因此我们取前缀和对一个质数取模后的结果
代码
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
#include<cmath>
inline int read(){
int x=0,fh=1;
char ch=getchar();
while(ch<'0' || ch>'9'){
if(ch=='-') fh=-1;
ch=getchar();
}
while(ch>='0' && ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x*fh;
}
const long long mod=1926081719491001LL;
//社会主义加成
const int maxn=1e6+5;
std::vector<int> g[maxn];
struct m_hash{
static const int Mod=261807;
struct asd{
int next;
long long val;
}b[maxn];
int head[maxn],tot;
m_hash(){
memset(head,-1,sizeof(head));
tot=1;
}
void ad(long long val){
int now=1LL*val%Mod;
for(int i=head[now];i!=-1;i=b[i].next){
long long u=b[i].val;
if(u==val) return;
}
b[tot].val=val;
b[tot].next=head[now];
head[now]=tot++;
}
int cx(long long val){
int now=1LL*val%Mod;
for(int i=head[now];i!=-1;i=b[i].next){
long long u=b[i].val;
if(u==val) return i;
}
return 0;
}
}mp;
//哈希表
int n,a[maxn];
long long sum[maxn];
long long gsc(long long aa,long long bb){
long long z=(long double)aa/mod*bb;
long long ans=(unsigned long long)aa*bb-(unsigned long long)z*mod;
return (ans+mod)%mod;
}
//O(1)光速乘
long long ksm(long long ds,long long zs){
long long ans=1;
while(zs){
if(zs&1LL) ans=gsc(ans,ds)%mod;
ds=gsc(ds,ds)%mod;
zs>>=1LL;
}
return ans;
}
//快速幂
int wz[maxn][22];
int get_val(long long val,int l,int r){
if(!mp.cx(val)) return 0;
int id=mp.cx(val);
return std::upper_bound(g[id].begin(),g[id].end(),r)-std::lower_bound(g[id].begin(),g[id].end(),l);
}
//取在区间[l,r]中值恰好为val的区间个数
int get_wz(int l,int r){
int k=log2(r-l+1);
if(a[wz[l][k]]<a[wz[r-(1<<k)+1][k]]) return wz[r-(1<<k)+1][k];
else return wz[l][k];
}
//找到区间最大值在哪里
int ans=0;
void solve(int l,int r){
if(l>r) return;
if(l==r){
ans++;
return;
}
int mids=get_wz(l,r);
solve(l,mids-1);
solve(mids+1,r);
long long max_val=ksm(2,1LL*a[mids]);
int zqj=mids-l,yqj=r-mids;
for(int cs=1;cs<=20;cs++,max_val=max_val*2%mod){
if(zqj<yqj){
for(int i=l;i<=mids;i++){
ans+=get_val((max_val+sum[i-1])%mod,mids,r);
}
} else {
for(int i=mids;i<=r;i++){
ans+=get_val((sum[i]-max_val+mod)%mod,l-1,mids-1);
}
}
}
}
//启发式合并
int main(){
n=read();
for(int i=1;i<=n;i++){
a[i]=read();
wz[i][0]=i;
}
mp.ad(0);
g[mp.tot-1].push_back(0);
for(int i=1;i<=n;i++){
sum[i]=sum[i-1]+ksm(2,1LL*a[i]);
sum[i]%=mod;
if(mp.cx(sum[i])) g[mp.cx(sum[i])].push_back(i);
else {
mp.ad(sum[i]);
g[mp.tot-1].push_back(i);
}
}
for(int j=1;j<=20;j++){
for(int i=1;i+(1<<j)-1<=n;i++){
if(a[wz[i][j-1]]<a[wz[i+(1<<(j-1))][j-1]]) wz[i][j]=wz[i+(1<<(j-1))][j-1];
else wz[i][j]=wz[i][j-1];
}
}
solve(1,n);
printf("%d
",ans);
return 0;
}