思路:
用线段树（每个节点存放区间内排好序的值，即归并树）或者分块均可。
遍历 i = 1 … n − 1，统计区间 [i + 1, min(a[i], n)] 中满足 a[j] ≥ i 的下标 j 的个数，累加起来即为答案（也就是统计满足 i < j ≤ a[i] 且 a[j] ≥ i 的数对 (i, j) 的个数）。
线段树:
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pb push_back
#define ls rt << 1, l, m
#define rs rt << 1 | 1, m + 1, r
#define mem(a, b) memset(a, b, sizeof(a))
const int N = 2e5 + 5;

// Merge-sort tree: vc[rt] holds the values a[l..r] of node rt's range,
// sorted in ascending order.
vector<int> vc[N<<2];
int a[N];

// Builds node rt covering positions [l, r] of a[] (1-indexed).
// The children are built first and their sorted lists are merged, so each
// tree level costs O(n) and the whole build is O(n log n). Copying and
// re-sorting a[l..r] at every node gives the same contents in O(n log^2 n).
// An empty range (l > r, e.g. n == 0) builds nothing instead of recursing
// forever.
void build(int rt, int l, int r) {
    if (l > r) return;
    if (l == r) {
        vc[rt].pb(a[l]);
        return;
    }
    int m = (l + r) >> 1;
    build(ls);
    build(rs);
    const vector<int>& lo = vc[rt << 1];
    const vector<int>& hi = vc[rt << 1 | 1];
    vc[rt].resize(lo.size() + hi.size());
    merge(lo.begin(), lo.end(), hi.begin(), hi.end(), vc[rt].begin());
}

// Counts positions j in [L, R] with a[j] >= L - 1, descending from node rt
// which covers [l, r]. The value threshold is tied to the range start (L - 1).
// main relies on this: it queries [i + 1, ...], so the threshold is exactly i.
// Returns 0 for an empty range (L > R).
int query(int L, int R, int rt, int l, int r) {
    if (L > R) return 0;
    if (L <= l && r <= R) {
        // Fully covered node: every value from the first one >= L - 1 counts.
        return static_cast<int>(vc[rt].end() - lower_bound(vc[rt].begin(), vc[rt].end(), L - 1));
    }
    int ans = 0;
    int m = (l + r) >> 1;
    if (L <= m) ans += query(L, R, ls);
    if (R > m) ans += query(L, R, rs);
    return ans;
}

// Reads n and a[1..n]. For every i in [1, n - 1], it adds the number of
// j in [i + 1, min(a[i], n)] with a[j] >= i.
// With n up to N the pair count can exceed the int range, so the total is
// accumulated in a long long.
int main() {
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    build(1, 1, n);
    LL ans = 0;
    for (int i = 1; i < n; i++) {
        ans += query(i + 1, min(a[i], n), 1, 1, n);
    }
    printf("%lld ", ans);
    return 0;
}
分块:
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pb push_back
#define mem(a, b) memset(a, b, sizeof(a))
const int N = 2e5 + 5;

// a[1..n]: input values.
int a[N];
// Declared but never read in this program.
int block[1005];
// belong[i]: 1-based index of the block containing position i.
int belong[N];
// Block length, floor(sqrt(n)).
int blo;
// vc[b]: the values of block b, sorted ascending after input is read.
vector<int> vc[1005];

// Counts indices j in [L, R] with a[j] >= L - 1 (the threshold follows the
// range start). Partial blocks at either end are scanned element by element.
// Every block strictly between them is answered by a binary search on its
// sorted copy. An empty range (L > R) yields 0.
int query(int L, int R) {
    if (L > R) return 0;
    const int threshold = L - 1;
    // Brute-force count of positions in [from, to] meeting the threshold.
    auto scan = [threshold](int from, int to) {
        int c = 0;
        for (int j = from; j <= to; ++j) c += (a[j] >= threshold);
        return c;
    };
    const int firstBlk = belong[L];
    const int lastBlk = belong[R];
    if (firstBlk == lastBlk) return scan(L, R);

    // Tail of the first block plus head of the last block.
    int hits = scan(L, firstBlk * blo) + scan((lastBlk - 1) * blo + 1, R);
    // Whole blocks in between: values from lower_bound onward all qualify.
    for (int b = firstBlk + 1; b < lastBlk; ++b) {
        hits += static_cast<int>(vc[b].end() - lower_bound(vc[b].begin(), vc[b].end(), threshold));
    }
    return hits;
}

// Reads n and a[1..n], splits positions into blocks of length floor(sqrt(n))
// and sorts each block's values. It then sums, over i in [1, n - 1], the count
// of j in [i + 1, min(n, a[i])] with a[j] >= i.
int main() {
    int n;
    scanf("%d", &n);
    for (int k = 1; k <= n; k++) scanf("%d", &a[k]);
    blo = static_cast<int>(sqrt(n));
    for (int k = 1; k <= n; k++) {
        belong[k] = (k - 1) / blo + 1;
        vc[belong[k]].pb(a[k]);
    }
    for (int b = 1; b <= belong[n]; b++) sort(vc[b].begin(), vc[b].end());
    LL total = 0;
    for (int k = 1; k < n; k++) total += query(k + 1, min(n, a[k]));
    printf("%lld ", total);
    return 0;
}