http://codeforces.com/contest/665/problem/E
给定一个序列,问其中有多少个区间,所有数字异或起来 >= k
看到异或,就应该想到异或的性质,A^B^B = A,所以,用sum[i]表示1--i的异或值,那么i...j的异或值就是sum[j] ^ sum[i - 1]
所以原问题转化为在一个数组中,任选两个数,异或值 >= k
首先所有数字 <= 1e9
那么前缀异或值 <= 2^31 - 1
那么用int就够了,从30位开始插入字典树,(1 << 31)是溢出的
查找的时候,分两类
1、当前的位置中k的二进制位是0,就是第15位,如果k的二进制位是0,那么往反方向走字典树的子树总是可以得,这个时候只需要加上去子树大小即可。
2、如果那一位k是1,我们没法马上判断贡献,那么只能往反方向贪心走字典树。
注意的就是叶子节点有可能没走到,需要最后加上来。
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
#define inf (0x3f3f3f3f)
typedef long long int LL;
#include <iostream>
#include <sstream>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <string>
const int maxn = 1e6 + 20;
const int N = 2;
int a[maxn];
int sum[maxn];
int n, k;
struct node {
struct node *pNext[N];
int cnt;
}tree[maxn * 40];
int t;
struct node * create() {
struct node *p = &tree[t++];
for (int i = 0; i < N; ++i) {
p->pNext[i] = 0;
}
p->cnt = 0;
return p;
}
void insert(struct node **T, int val) {
struct node *p = *T;
if (p == NULL) {
p = *T = create();
}
for (int i = 30; i >= 0; --i) {
int id = (val & (1 << i)) > 0;
if (p->pNext[id] == NULL) {
p->pNext[id] = create();
}
p = p->pNext[id];
p->cnt++;
}
return ;
}
LL find(struct node *T, int val) {
struct node *p = T;
LL ans = 0;
if (p == NULL) return ans;
for (int i = 30; i >= 0; --i) {
int id = (val & (1 << i)) > 0;
int id_k = (k & (1 << i)) > 0;
if (id_k == 0) {
if (p->pNext[!id] != NULL)
ans += p->pNext[!id]->cnt; //这边的树全部都可以
if (p->pNext[id] == NULL) return ans;
p = p->pNext[id]; //如果走到了叶子节点,我这里return了,不会重复增加
} else { //id_k = 1
if (p->pNext[!id] == NULL) return ans;
p = p->pNext[!id];
}
}
return ans + p->cnt; //叶子节点有可能没算到,参考样例1
}
void work() {
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; ++i) {
scanf("%d", &a[i]);
sum[i] = sum[i - 1] ^ a[i];
// cout << sum[i] << " ";
}
// printf("%d\n", 1 << 30);
struct node *T = NULL;
insert(&T, sum[1]);
LL ans = 0;
for (int i = 2; i <= n; ++i) {
ans += find(T, sum[i]);
insert(&T, sum[i]);
// cout << find(T, sum[i]) << endl;
}
for (int i = 1; i <= n; ++i) ans += sum[i] >= k;
cout << ans << endl;
return;
}
int main() {
#ifdef local
freopen("data.txt","r",stdin);
#endif
work();
return 0;
}
View Code