[Codeforces Round #626 (Div. 2, based on Moscow Open Olympiad in Informatics)] D. Present (异或性质, 按位拆分, 树状数组)
D. Present
time limit per test
3 seconds
memory limit per test
512 megabytes
standard input
standard output
Catherine received an array of integers as a gift for March 8. Eventually she grew bored with it and started calculating various useless characteristics of it. She succeeded in computing every one she came up with. But when she came up with another one — the xor of all pairwise sums of elements in the array — she realized that she couldn't compute it for a very large array, so she asked for your help. Can you do it? Formally, you need to compute $\bigoplus_{1 \le i < j \le n} (a_i + a_j)$.
Here $x \oplus y$ is the bitwise XOR operation (i.e. `x ^ y` in many modern programming languages). You can read about it on Wikipedia: https://en.wikipedia.org/wiki/Exclusive_or#Bitwise_operation.
The first line contains a single integer $n$ ($2 \le n \le 400\,000$) — the number of integers in the array.
The second line contains integers $a_1, a_2, \ldots, a_n$ ($1 \le a_i \le 10^7$).
Print a single integer — xor of all pairwise sums of integers in the given array.
Sample 1 — input: `2` / `1 2`; output: `3`.
Sample 2 — input: `3` / `1 2 3`; output: `2`.
In the first sample case there is only one sum: $1+2=3$.
In the second sample case there are three sums: $1+2=3$, $1+3=4$, $2+3=5$. In binary they are $011_2$, $100_2$ and $101_2$, and $011_2 \oplus 100_2 \oplus 101_2 = 010_2$, thus the answer is 2.
$\oplus$ is the bitwise xor operation. To define $x \oplus y$, consider the binary representations of the integers $x$ and $y$. The $i$-th bit of the result is 1 when exactly one of the $i$-th bits of $x$ and $y$ is 1; otherwise the $i$-th bit of the result is 0. For example, $0101_2 \oplus 0011_2 = 0110_2$.
$$(a_1 + a_2) \oplus (a_1 + a_3) \oplus \ldots \oplus (a_1 + a_n) \\ \oplus (a_2 + a_3) \oplus \ldots \oplus (a_2 + a_n) \\ \ldots \\ \oplus (a_{n-1} + a_n)$$
对于第 $k$ 位(从第 0 位开始),和的第 $k$ 位只取决于两个加数的低 $k+1$ 位,所以我们先将所有的 $a_i$ 对 $2^{k+1}$ 取模,再统计有多少对 $(i,j)$($i<j$)使得 $sum=a_i+a_j$ 的第 $k$ 位为 1;若这样的对数为奇数,则答案的第 $k$ 位为 1,否则为 0。
取模后两个数都小于 $2^{k+1}$,因此 $sum \le 2^{k+2}-2$。$sum$ 的第 $k$ 位为 1,当且仅当 $sum$ 落在 $[2^k, 2^{k+1})$ 或 $[2^{k+1} + 2^k, 2^{k+2} - 2]$ 这 2 个区间中。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#include <sstream>
#include <bitset>
// ---- Contest-template shorthands (only some are used by the solution below) ----
// ALL(x): expands to the begin/end iterator pair of a container, for <algorithm> calls.
#define ALL(x) (x).begin(), (x).end()
// sz(a): container size converted to int (avoids signed/unsigned comparison warnings).
#define sz(a) int(a.size())
// rep: loop i over the half-open range [x, n); repd: loop i over the closed range [x, n].
// Note: n is re-evaluated on every iteration.
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
// Pair type shorthands.
#define pii pair<int,int>
#define pll pair<long long ,long long>
// gbtb: faster iostreams. NOTE: after this, mixing cin/cout with C stdio
// (scanf/printf/getchar, which this file uses) is no longer safe.
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
// MS0: zero-fill a whole array (sizeof must see the array, not a pointer).
#define MS0(X) memset((X), 0, sizeof((X)))
// MSC0: fills every byte with the space character ' ' (0x20).
// NOTE(review): the name suggests a zero/char fill like MS0; if '\0' was intended, fix it — confirm.
#define MSC0(X) memset((X), ' ', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
// Presumably a floating-point comparison tolerance (unused in this solution).
#define eps 1e-6
// chu(x): debug print of an expression's source text and its value.
#define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
// du1/du2/du3: read one, two or three ints with scanf.
// Note: du1 already ends with ';'.
#define du3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c))
#define du2(a,b) scanf("%d %d",&(a),&(b))
#define du1(a) scanf("%d",&(a));
using namespace std;
// 64-bit integer shorthand used throughout the file.
typedef long long ll;
// Greatest common divisor by Euclid's algorithm; gcd(a, 0) == a.
long long gcd(long long a, long long b)
{
    while (b)
    {
        const long long r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Least common multiple. Returns 0 when either argument is 0; the guard avoids
// evaluating 0 / gcd(0, 0) == 0 / 0. Divides before multiplying to reduce overflow risk.
long long lcm(long long a, long long b)
{
    if (a == 0 || b == 0)
        return 0;
    return a / gcd(a, b) * b;
}
// a^b mod MOD by binary exponentiation; the result lies in [0, MOD).
// Assumes MOD >= 1. Intermediate products are a product of two residues, so MOD must stay
// below ~3e9 to fit in long long. A negative exponent is not supported (the loop does not run).
long long powmod(long long a, long long b, long long MOD)
{
    if (a == 0ll)
        return 0ll;          // kept from the original: 0^b is reported as 0 (also for b == 0)
    a %= MOD;
    if (a < 0)
        a += MOD;            // C++ '%' keeps the dividend's sign; normalize into [0, MOD)
    long long ans = 1 % MOD; // '1 % MOD' so that MOD == 1 yields 0 rather than 1
    while (b > 0)            // 'b > 0' (not 'b'): a negative b would otherwise never reach 0
    {
        if (b & 1)
            ans = ans * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return ans;
}
// Exact integer power a^b (no modulus) by binary exponentiation, for b >= 0.
// The caller must ensure the result fits in long long. Kept from the original: returns 0
// when a == 0, even for b == 0. A negative b is not supported (returns 1).
long long poww(long long a, long long b)
{
    if (a == 0ll)
        return 0ll;
    long long ans = 1;
    while (b > 0)
    {
        if (b & 1)
            ans = ans * a;
        b >>= 1;
        // Square only if another exponent bit remains. The unconditional squaring after the
        // last bit overflowed (signed overflow is UB) even when the result itself fits,
        // e.g. poww(2, 40).
        if (b > 0)
            a = a * a;
    }
    return ans;
}
// Prints the elements of V on one line, separated by single spaces and terminated by a
// newline. An empty vector prints nothing (not even the newline).
void Pv(const std::vector<int> &V)
{
    const int Len = static_cast<int>(V.size());
    for (int i = 0; i < Len; ++i)
    {
        printf("%d", V[i]);
        if (i != Len - 1)
            printf(" ");
        else
            printf("\n");
    }
}
// Prints the elements of V (long long) on one line, separated by single spaces and terminated
// by a newline. An empty vector prints nothing (not even the newline).
void Pvl(const std::vector<long long> &V)
{
    const int Len = static_cast<int>(V.size());
    for (int i = 0; i < Len; ++i)
    {
        printf("%lld", V[i]);
        if (i != Len - 1)
            printf(" ");
        else
            printf("\n");
    }
}
// Reads the next integer from stdin as long long, skipping any non-digit characters before it.
// A '-' makes the number negative only when it immediately precedes the digits.
// Returns 0 if end of input is reached before any digit (previously this looped forever).
// The character that ends the number is consumed.
inline long long readll()
{
    long long tmp = 0, fh = 1;
    int c = getchar();   // int, not char: must be able to hold EOF
    while (c != EOF && (c < '0' || c > '9'))
    {
        fh = (c == '-') ? -1 : 1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        tmp = tmp * 10 + (c - '0');
        c = getchar();
    }
    return tmp * fh;
}
// Reads the next integer from stdin as int, skipping any non-digit characters before it.
// A '-' makes the number negative only when it immediately precedes the digits.
// Returns 0 if end of input is reached before any digit (previously this looped forever).
// The character that ends the number is consumed.
inline int readint()
{
    int tmp = 0, fh = 1;
    int c = getchar();   // int, not char: must be able to hold EOF
    while (c != EOF && (c < '0' || c > '9'))
    {
        fh = (c == '-') ? -1 : 1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        tmp = tmp * 10 + (c - '0');
        c = getchar();
    }
    return tmp * fh;
}
// Fenwick tree (binary indexed tree) of counts indexed by value, for values 0 .. maxn - 2.
// A value v is stored at internal index v + 1. A Fenwick tree cannot use index 0:
// lowbit(0) == 0, so the old add(0, val) never advanced and looped forever, and main()
// inserts residues a_j % 2^(i+1) that are often 0.
// a_i <= 1e7 per the statement, so every residue fits. The old 1e8-entry table (400 MB)
// existed only so that ask() could index far out-of-range query bounds; those are now
// clamped instead.
const int maxn = 10000010;
const int inf = 0x3f3f3f3f;
int tree[maxn];

// Lowest set bit of x (0 for x == 0).
int lowbit(int x)
{
    return x & (-x);
}

// Adds val to the count stored for value pos. Values outside [0, maxn - 2] are ignored.
void add(int pos, int val)
{
    if (pos < 0 || pos > maxn - 2)
        return;
    for (int i = pos + 1; i < maxn; i += lowbit(i))
        tree[i] += val;
}

// Returns the total count of values in [0, pos]; 0 for pos < 0.
// pos may lie beyond the table (main's query bounds reach about 2^26), so it is clamped to
// the last slot, which already covers every stored value.
int ask(int pos)
{
    int i = (pos >= maxn - 1) ? maxn - 1 : pos + 1;
    int res = 0;
    for (; i > 0; i -= lowbit(i))
        res += tree[i];
    return res;
}
// Number of array elements (read first in main).
int n;
// The input array, stored 1-indexed as a[1..n]; sized for n <= 400000 from the statement.
int a[400010];
int query(int l, int r)
return ask(r) - ask(l - 1);
int main()
n = readint();
repd(i, 1, n)
a[i] = readint();
ll ans = 0ll;
for (int i = 0; i <= 24; ++i)
ll res = 0ll;
repd(j, 1, n)
int x = a[j] % (1 << i + 1);
res += query((1 << i) - x, (1 << i + 1) - 1 - x);
res += query((1 << i) + (1 << i + 1) - x, (1 << i + 2) - 2 - x);
add(x, 1);
repd(j, 1, n)
int x = a[j] % (1 << i + 1);
add(x, -1);
if (res & 1)
ans += (1 << i);
", ans );
return 0;