思路
- 首先考虑不带区间加的情况,显然容易想到对每个数的每一个二进制位维护一个树状数组。设一个树状数组维护的是二进制的第(k)位,那就每次往里面存(num)的时候在这个树状数组的第(num mod 2^k)这个位置(+1),那么我们最后查询的时候,只要找到对应的那个树状数组,统计答案(query(2^{k+1}-1)-query(2^k-1))即可求得答案。 注意这里对于(2^k)取模其实就是(num)和(2^k)按位求与就行了。
- 然后考虑带区间加的情况。由于是全局加,所以考虑维护一个全局变量(sum)来记录变化量。接着我们来考虑这个东东对答案的影响。考虑当(k=3)时,我们想要的答案区间在([111_{(2)},100_{(2)}])。但是由于要加的这个(sum)使得答案区间发生了变化。这个变化可以分为两种。
- 当一个更小的数通过加上这个(sum)从而达到了这个答案区间。那我们怎么获得这个数呢?那就考虑把左右端点都减去(sum)就可以得到这个区间了。而显然(sum)的二进制下大于(2^{k+1}-1)的部分是没有意义的,因为它没法对这个区间的答案贡献。所以还是先吧(sum)对(2^{k+1})次方取模,方法同上。然后左右端点分别减去就好了。
- 另一个可能是由于进位,可能本来就在这个区间的数加上(sum)之后还在答案区间。显然这部分的答案和前面那一部分是互不包含的。怎么处理呢?其实就是原来的一个数(num)在加上(sum)以后,到了区间([2^{k+1}+2^{k+1}-1,2^{k+1}+2^k])从而它们的答案还是要计算的。然而我这里的栗子实际上只给出了它进一位的情况,显然它进很多位也是合法的。但是其实我们只关心它的后(k)位,所以无论进多少位,处理方法没得区别。就是先让左右端点加上(2^{k+1}),然后按前面那个方法来处理就可以了。
代码
由于刚开始我也不会,所以代码大多借鉴了网上题解的做法,多使用位运算解决这个问题。实际上换成取模也是可以的。但是位运算会快一点。
但是位运算代码可读性是真低
#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <map>
#include <algorithm>
using namespace std;
#define R register
#define LL long long
const int inf = 0x3f3f3f3f;
const int MAXN = (1 << 16) + 10;
inline int read()
{
char a = getchar();
int x = 0, f = 1;
for (; a > '9' || a < '0'; a = getchar())
if (a == '-')
f = -1;
for (; a >= '0' && a <= '9'; a = getchar())
x = x * 10 + a - '0';
return x * f;
}
int sum;
struct BIT
{
private:
int c[MAXN];
inline int lowbit(int x) { return x & -x; }
public:
inline int ask(int x)
{
int ans = 0;
for (; x; x -= lowbit(x))
ans += c[x];
return ans;
}
inline void update(int x, int y)
{
for (; x < MAXN; x += lowbit(x))
c[x] += y;
}
} bit[16];
map<int, int> mp;
int main()
{
freopen("a.in", "r", stdin);
//freopen(".out","w",stdout);
char ch[10];
int x;
int n = read();
while (n--)
{
scanf("%s", ch);
x = read();
if (ch[0] == 'A')
sum += x;
if (ch[0] == 'I')
{
x -= sum;
mp[x]++;
for (R int i = 0; i < 16; i++)
bit[i].update((x & ((1 << (i + 1)) - 1)) + 1, 1);
}
if (ch[0] == 'D')
{
x -= sum;
int cnt = mp[x];
mp[x] = 0;
for (R int i = 0; i < 16; i++)
bit[i].update((x & ((1 << (i + 1)) - 1)) + 1, -cnt);
}
if (ch[0] == 'Q')
{
int ans = 0;
int l = 1 << x;
int r = (1 << (x + 1)) - 1;
ans += bit[x].ask(min(1 << 16, max(0, r - (sum & ((1 << (x + 1)) - 1)) + 1)));
ans -= bit[x].ask(min(1 << 16, max(0, l - (sum & ((1 << (x + 1)) - 1)))));
l |= (1 << (x + 1));
r |= (1 << (x + 1));
ans += bit[x].ask(min(1 << 16, max(0, r - (sum & ((1 << (x + 1)) - 1)) + 1)));
ans -= bit[x].ask(min(1 << 16, max(0, l - (sum & ((1 << (x + 1)) - 1)))));
printf("%d
", ans);
}
}
return 0;
}
给出一个用取模实现的版本,注意负数的影响
#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <map>
#include <algorithm>
using namespace std;
#define R register
#define LL long long
const int inf = 0x3f3f3f3f;
const int MAXN = (1 << 16) + 10;
inline int read()
{
char a = getchar();
int x = 0, f = 1;
for (; a > '9' || a < '0'; a = getchar())
if (a == '-')
f = -1;
for (; a >= '0' && a <= '9'; a = getchar())
x = x * 10 + a - '0';
return x * f;
}
int sum;
struct BIT
{
private:
int c[MAXN];
inline int lowbit(int x) { return x & -x; }
public:
inline int ask(int x)
{
int ans = 0;
for (; x; x -= lowbit(x))
ans += c[x];
return ans;
}
inline void update(int x, int y)
{
//printf("%d
",x);
for (; x < MAXN; x += lowbit(x))
c[x] += y;
}
} bit[16];
map<int, int> mp;
int bs[20];
int main()
{
freopen("a.in", "r", stdin);
freopen("a.out","w",stdout);
char ch[10];
int x;
int n = read();
bs[0]=1;
for(R int i=1;i<=16;i++) bs[i]=bs[i-1]*2;
while (n--)
{
scanf("%s", ch);
x = read();
if (ch[0] == 'A')
sum += x;
if (ch[0] == 'I')
{
x -= sum;
mp[x]++;
for (R int i = 0; i < 16; i++) {
int t=x;
t%=bs[i+1];t+=bs[i+1];t%=bs[i+1];
bit[i].update(t+1,1);
}
}
if (ch[0] == 'D')
{
x -= sum;
int cnt = mp[x];
mp[x] = 0;
for (R int i = 0; i < 16; i++){
int t=x;
t%=bs[i+1];t+=bs[i+1];t%=bs[i+1];
bit[i].update(t+1,-cnt);
}
}
if (ch[0] == 'Q')
{
int ans = 0;
int l = bs[x];
int r = bs[x+1]-1;
int t=sum; t%=bs[x+1];t+=bs[x+1];t%=bs[x+1];
ans += bit[x].ask(min(1 << 16, max(0, r - t+1)));
ans -= bit[x].ask(min(1 << 16, max(0, l - t)));
l |= (1 << (x + 1));
r |= (1 << (x + 1));
ans += bit[x].ask(min(1 << 16, max(0, r - t+1)));
ans -= bit[x].ask(min(1 << 16, max(0, l - t)));
printf("%d
", ans);
}
}
return 0;
}