题目描述:
pupil 发现对于一个十进制数,无论怎么将其的数字重新排列,均不影响其是不是 的倍数。他想研究对于二进制,是否也有类似的性质。
于是他生成了一个长为n 的二进制串,希望你对于这个二进制串的一个子区间,能求出其有多少位置不同的连续子串,满足在重新排列后(可包含前导0)是一个3 的倍数。
两个位置不同的子区间指开始位置不同或结束位置不同。
由于他想尝试尽量多的情况,他有时会修改串中的一个位置,并且会进行多次询问。
首先,看到区间的子串,想到线段树区间合并。
考虑一个子问题:判断一个子串是否满足条件。
首先,二进制化为十进制的方法是对应位×位权。
可以发现,位权(mod3)的值是(1,2,1,2)交替。
设这个子串有x个1放在2上,y个放在1上,则应满足((2x+y)mod3=0),即(x=y(mod3))。
设该子串长度为len,有s个1。
那么,贪心的想,x和y相差0或3。
分情况讨论:
- (len\%2=0),(s\%2=0)。需要满足(s/2<=len/2),显然可以。
- (len\%2=0),(s\%2=1)。即(x+y=s,x-y=3),解得(x=(s+3)/2)。需要满足((s+3)/2<=len/2),即(s<=len-3)。
- (len\%2=1),(s\%2=0)。需要满足(s/2<=(len-1)/2),显然可以。
- (len\%2=1),(s\%2=1)。同理,(x=(s+3)/2)。需要满足((s+3)/2<=(len+1)/2),即(s<=len-2)。
可以发现,s为偶数时一定可以。s为奇数时比较复杂,难以统计(因为即使len确定,满足要求的s有很多,无法进行区间合并)。
考虑容斥,求不合法的。
若不合法,则(s\%2=1)。
分两种情况讨论:
- (len\%2=0),即(s>=len-2)。因为(s\%2=1,s<=len),所以(s=len-1)。
- (len\%2=1),即(s>=len-1)。因为(s\%2=1,s<=len),所以(s=len)。
此外,若(s=1),也不合法。把这部分减去,再加上(s=1)且(len<=2)的情况(被算了两次)。
时间复杂度:(O(mlogn)),常数很大。
注意区间合并的细节。
代码:
#include <stdio.h>
#define ll long long
struct SJd {
ll s[2][2];
SJd() {
s[0][0] = s[0][1] = s[1][0] = s[1][1] = 0;
}
};
SJd operator * (SJd a, SJd b) {
SJd rt;
for (int x = 0; x < 2; x++) {
for (int y = 0; y < 2; y++) {
int z = (x + y) % 2;
rt.s[z][0] += a.s[x][0] * b.s[y][0];
rt.s[z][1] += a.s[x][0] * b.s[y][1] + a.s[x][1] * b.s[y][0];
}
}
return rt;
}
SJd operator + (SJd a, SJd b) {
SJd rt;
rt.s[0][0] = a.s[0][0] + b.s[0][0];
rt.s[0][1] = a.s[0][1] + b.s[0][1];
rt.s[1][0] = a.s[1][0] + b.s[1][0];
rt.s[1][1] = a.s[1][1] + b.s[1][1];
return rt;
}
SJd he[400010],zg[400010],zl[400010],zr[400010];
ll s0[400010],s1[400010],s2[400010];
bool h0[400010],h1[400010];
int sz[100010],l0[400010],l1[400010],r0[400010],r1[400010],su[400010],ma = 0,sl;
void fuz(int i, int j) {
he[i]=he[j];zg[i]=zg[j];
zl[i]=zl[j];zr[i]=zr[j];
s0[i]=s0[j];s1[i]=s1[j];s2[i]=s2[j];
h0[i]=h0[j];h1[i]=h1[j];
l0[i]=l0[j];l1[i]=l1[j];
r0[i]=r0[j];r1[i]=r1[j];
su[i]=su[j];
}
void merge(int i, int cl, int cr, int m) {
h0[i] = h0[cl] && h0[cr];
h1[i] = (h1[cl] && h0[cr]) || (h0[cl] && h1[cr]);
l0[i] = l0[cl];
if (h0[cl]) l0[i] += l0[cr];
r0[i] = r0[cr];
if (h0[cr]) r0[i] += r0[cl];
l1[i] = l1[cl];
if (h0[cl]) l1[i] += l1[cr];
if (h1[cl]) l1[i] += l0[cr];
r1[i] = r1[cr];
if (h0[cr]) r1[i] += r1[cl];
if (h1[cr]) r1[i] += r0[cl];
s0[i] = s0[cl] + s0[cr] + 1ll * r0[cl] * l0[cr];
s1[i] = s1[cl] + s1[cr] + 1ll * r0[cl] * l1[cr] + 1ll * r1[cl] * l0[cr];
s2[i] = s2[cl] + s2[cr] + (sz[m - 1] + sz[m] == 1);
zg[i] = zg[cl] * zg[cr];
zl[i] = zl[cl] + zg[cl] * zl[cr];
zr[i] = zr[cr] + zg[cr] * zr[cl];
he[i] = he[cl] + he[cr] + zr[cl] * zl[cr];
su[i] = su[cl] + su[cr];
}
void pushup(int i, int l, int r) {
int m = (l + r) >> 1;
merge(i, i << 1, (i << 1) | 1, m);
}
void getddz(int i, int x) {
l0[i] = r0[i] = s0[i] = h0[i] = (x == 0);
l1[i] = r1[i] = s1[i] = h1[i] = (x == 1);
he[i] = zg[i] = zl[i] = zr[i] = SJd();
he[i].s[1][x ^ 1] += 1;
zg[i].s[1][x ^ 1] += 1;
zl[i].s[1][x ^ 1] += 1;
zr[i].s[1][x ^ 1] += 1;
su[i] = x;
s2[i] = 0;
}
void jianshu(int i, int l, int r) {
if (i > ma) ma = i;
if (l + 1 == r) {
getddz(i, sz[l]);
return;
}
int m = (l + r) >> 1;
jianshu(i << 1, l, m);
jianshu((i << 1) | 1, m, r);
pushup(i, l, r);
}
void xiugai(int i, int l, int r, int j) {
if (l + 1 == r) {
sz[l] ^= 1;
getddz(i, sz[l]);
return;
}
int m = (l + r) >> 1;
if (j < m) xiugai(i << 1, l, m, j);
else xiugai((i << 1) | 1, m, r, j);
pushup(i, l, r);
}
void getans(int i, int l, int r, int L, int R) {
if (R <= l || r <= L) return;
if (L <= l && r <= R) {
if (sl == ma) fuz(sl + 1, i);
else merge(sl + 1, sl, i, l);
sl += 1;
return;
}
int m = (l + r) >> 1;
getans(i << 1, l, m, L, R);
getans((i << 1) | 1, m, r, L, R);
}
ll getans(int n, int l, int r) {
int s = r - l + 1;
ll zs = 1ll * s * (s + 1) / 2;
sl = ma;
getans(1, 1, n + 1, l, r + 1);
zs -= (he[sl].s[0][1] + he[sl].s[1][0] + s1[sl]);
zs += (su[sl] + s2[sl]);
return zs;
}
int main() {
int n,m;
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &sz[i]);
scanf("%d", &m);
jianshu(1, 1, n + 1);
for (int i = 0; i < m; i++) {
int lx;
scanf("%d", &lx);
if (lx == 1) {
int a;
scanf("%d", &a);
xiugai(1, 1, n + 1, a);
} else {
int a,b;
scanf("%d%d", &a, &b);
printf("%lld
", getans(n, a, b));
}
}
return 0;
}