题意:
题解:
询问l~r的答案,很显然我们直接转化成0~r的答案减去0~(l-1)的答案。
如何求0~a的答案,从高位往低位枚举。
(1)如果当前位置在a中是1,那么这个位置可以取0,也可以取1。如果取0,那后面的随便取都无所谓,也就是说,后面的位置的0,1是可以随便乱放的,也就是说当前这一位为0^(x的这一位),然后之后位随便取,恰好是一个区间,可以求区间的答案。接下来把这一位取1^(x的这一位)然后继续往低位走。
(2)如果当前位置在a中是0,那么这个位置只能取0,就把这一位取0^(x的这一位),接着走。
接下来就是快速查询一段区间的答案,也就是一段区间的f(i)值。方法很多,可以二分直接做,然而我写了颗线段树维护。。。
#include<cstdio> #include<algorithm> #include<cstdlib> using namespace std; const int mod=998244353,INF=(1<<30)-1; int n,q,a[100002],cnt=1; typedef struct{ int ls,rs; long long sum,f; }P; P p[10000002]; void pushdown(int root,int begin,int mid,int end){ if (p[root].f) { if (!p[root].ls)p[root].ls=++cnt; if (!p[root].rs)p[root].rs=++cnt; p[p[root].ls].sum=(p[p[root].ls].sum+(mid-begin+1)*p[root].f)%mod; p[p[root].rs].sum=(p[p[root].rs].sum+(end-mid)*p[root].f)%mod; p[p[root].ls].f=(p[p[root].ls].f+p[root].f)%mod; p[p[root].rs].f=(p[p[root].rs].f+p[root].f)%mod; p[root].f=0; } } void gx(int root,int begin,int end,int begin2,int end2,long long z){ if (begin>=begin2 && end<=end2) { p[root].sum=(p[root].sum+z*(end-begin+1))%mod;p[root].f=(p[root].f+z)%mod; return; } int mid=(begin+end)/2;pushdown(root,begin,mid,end); if (!(begin>end2 || mid<begin2)) { if (!p[root].ls)p[root].ls=++cnt; gx(p[root].ls,begin,mid,begin2,end2,z); } if (!(mid+1>end2 || end<begin2)) { if (!p[root].rs)p[root].rs=++cnt; gx(p[root].rs,mid+1,end,begin2,end2,z); } p[root].sum=(p[p[root].ls].sum+p[p[root].rs].sum)%mod; } long long cx(int root,int begin,int end,int begin2,int end2){ if (begin>end2 || end<begin2 || !root)return 0; if (begin>=begin2 && end<=end2)return p[root].sum; int mid=(begin+end)/2;pushdown(root,begin,mid,end); return (cx(p[root].ls,begin,mid,begin2,end2)+cx(p[root].rs,mid+1,end,begin2,end2))%mod; } int js(int x,int z){ if (x<0)return 0; int rt=1,ans=0,t=0; for (int i=30;i>=0;i--) if ((1<<i)&x) { bool u; if ((1<<i)&z)u=1;else u=0; ans=(ans+cx(1,0,INF,t+(1<<i)*u,t+(1<<i)*u+((1<<i)-1)))%mod; t+=(1<<i)*(u^1); } else { bool u; if ((1<<i)&z)u=1;else u=0; t+=(1<<i)*u; } ans=(ans+cx(1,0,INF,t,t))%mod; return ans; } int main() { scanf("%d%d",&n,&q); for (int i=1;i<=n;i++)scanf("%d",&a[i]); sort(a+1,a+n+1); a[n+1]=INF+1; for (int i=1;i<=n;i++) if (a[i]!=a[i+1])gx(1,0,INF,a[i],a[i+1]-1,(long long)i*i%mod); for (int i=1;i<=q;i++) { int l,r,x; scanf("%d%d%d",&l,&r,&x); printf("%d ",((js(r,x)-js(l-1,x))%mod+mod)%mod); } return 0; }