第一次看到这题应该是在大一,当时想学一手莫队,但是由于过于菜(其实主要是心理上的原因)没有学成。
原题链接https://www.luogu.com.cn/problem/P1972
这道题的做法还算蛮多的,首先,这道题很明显的一道莫队的板子题,也可以用主席树来维护。我后来看了题解发现还可以巧妙的运用树状数组来解决(但是我这种没有脑子的选手是想不出来的)
这边会陆续更新三种解法。
1.莫队
众所周知,莫队算法是一种玄学骗分区间操作的数据结构算法。它解决的一般是求区间内出现的不同数字的个数。
举个栗子
问题:
有一个长为N序列,有M个询问:在区间[L,R]内,出现了多少个不同的数字。(序列中所有数字均小于K)。<———莫队算法就是解决这种问题的算法(本蒟蒻现在只会最基本的)
问题很简单,那么我们怎么解决呢?
void add ( int pos ) { ++cnt[a[pos]] ; if ( cnt[a[pos]] == 1 ) ++ answer ; } void remove ( int pos ) { -- cnt[a[pos]] ; if ( cnt[a[pos]] == 0 ) -- answer ; } void solve() { int curL = 1, curR = 0 ; // current L R for ( each query [L,R] ) { while ( curL < L ) remove ( curL++ ) ; while ( curL > L ) add ( --curL ) ; while ( curR < R ) add ( ++curR ) ; while ( curR > R ) remove ( curR-- ) ; cout << answer << endl ; // Warning : please notice the order "--","++" and "cur" ; } }
这种做法比暴力快很多(造几组数据手玩一下还是蛮好理解的),事实上,这就是莫队的核心思路。
虽然这种做法比暴力要快很多了, 但是还是非常的慢。
接下来的优化就是算法的核心所在了。这种所发之所以慢是因为curL和curR进行了太多无效的移动,怎么增加他们移动的效率呢?
我们可以把要查询的区间先离线下来,按照他们的左端点、右端点给他们分块排序,就可以大大提升效率。玄学高效吧!
理解了这些,这道项链题就是一道板子题了。
但是,洛谷上这题加强了数据范围卡莫队(对没错,这道莫队的入门题被卡莫队了......)
要交的话可以在https://www.luogu.com.cn/problem/SP3267这个网址进行提交,这边数据进行了弱化。
代码附上
#include <stdio.h> #include <iostream> #include <cstring> #include <algorithm> #include <cmath> #include <queue> #include <map> #include <stack> #pragma GCC optimize(2) #define mm(i,v) memset(i,v,sizeof i); #define mp(a, b) make_pair(a, b) #define one first #define two second //你冷静一点,确认思路再敲!!! using namespace std; typedef long long ll; typedef pair<int, int > PII; const int N = 1e6 + 5, mod = 1e9 + 9, INF = 0x3f3f3f3f; int n, m, answer; int cnt[N], a[N], ans[N]; int tim; inline int read(){ char c=getchar();int x=0,f=1; while(c<'0'||c>'9'){if(c=='-')f=-1; c=getchar();} while(c>='0'&&c<='9'){x=x*10+c-'0'; c=getchar();} return x*f; } struct node { int l, r, id; }list[N]; bool cmp(node a, node b) { return (a.l/tim) == (b.l/tim) ? a.r < b.r : a.l < b.l ; } void add ( int pos ) { ++cnt[a[pos]] ; if ( cnt[a[pos]] == 1 ) ++ answer ; } void remove ( int pos ) { -- cnt[a[pos]] ; if ( cnt[a[pos]] == 0 ) -- answer ; } int main() { cin >> n; int curL = 1, curR = 0; for (int i = 1; i <= n; ++i) a[i] = read(); cin >> m; tim = sqrt(m); for (int i = 1; i <= m; ++i) { list[i].l = read(); list[i].r = read(); list[i].id = i; } sort(list + 1, list + 1 + m, cmp); for (int i = 1; i <= m; ++i) { int L = list[i].l, R = list[i].r; while (curL < L) remove(curL++); while (curL > L) add(--curL); while (curR < R) add(++curR); while (curR > R) remove(curR--); ans[list[i].id] = answer; } for (int i = 1; i <= m; ++i) { printf("%d ", ans[i]); } }
上网课去,溜了溜了。有时间更别的做法。
UPD:2020.7.25
说了有时间更别的做法结果过了5个月才回来更新hhh
主席树做法:last[i]表示i位置的数上一个出现位置,查询区间中last[i]<l的个数,序列建主席树,last权值线段树上就是[0...l-1]的权值和
但是很难受的是因为数据的加强这份代码还是不能通过所有的数据(但是可以通过弱化版)
#include <stdio.h> #include <iostream> #include <cstring> #include <algorithm> #include <cmath> #include <queue> #include <map> #include <stack> #include <sstream> #include <set> #pragma GCC optimize(2) //#define int long long #define mm(i,v) memset(i,v,sizeof i); #define mp(a, b) make_pair(a, b) #define pi acos(-1) #define fi first #define se second //你冷静一点,确认思路再敲!!! using namespace std; typedef long long ll; typedef pair<int, int > PII; priority_queue< PII, vector<PII>, greater<PII> > que; stringstream ssin; // ssin << string while ( ssin >> int) const int N = 2e6 + 5, M = 1e4 + 5, mod = 1e9 + 9, INF = 0x3f3f3f3f; int n, m; inline int read(){ char c=getchar();int x=0,f=1; while(c<'0'||c>'9'){if(c=='-')f=-1; c=getchar();} while(c>='0'&&c<='9'){x=x*10+c-'0'; c=getchar();} return x*f; } struct node { int l, r; int cnt; }tr[N * 4 + N * 17]; int root[N], idx; int last[N], pos[N]; inline int build(int l, int r) { int p = ++idx; if (l == r) return p; int mid = l + r >> 1; tr[p].l = build(l, mid); tr[p].r = build(mid + 1, r); return p; } inline int insert(int p, int l, int r, int x) { int q = ++idx; tr[q] = tr[p]; if (l == r) { tr[q].cnt++; return q; } int mid = l + r >> 1; if (x <= mid) tr[q].l = insert(tr[p].l, l, mid, x); else tr[q].r = insert(tr[p].r, mid + 1, r, x); tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt; return q; } inline int query(int q, int p, int l, int r, int k) { if (l == r) return tr[q].cnt - tr[p].cnt; int cnt = tr[tr[q].l].cnt - tr[tr[p].l].cnt; int mid = l + r >> 1; if (k <= mid) return query(tr[q].l, tr[p].l, l, mid, k); else return cnt + query(tr[q].r, tr[p].r, mid + 1, r, k); } int main() { n = read(); for (int i = 1; i <= n; ++i) { int x; x = read(); last[i] = pos[x]; pos[x] = i; } m = read(); root[0] = build(0, N); for (int i = 1; i <= n; ++i) root[i] = insert(root[i - 1], 0, n, last[i]); while (m--) { int l, r; l = read(); r = read(); printf("%d ", query(root[r], root[l - 1], 0, n, l - 1)); } // system("pause"); return 0; }
最后应该是本题正解(至少是出题人希望的算法)用树状数组解决这个问题
其实想明白之后这道题无论是写法还是时间复杂度都比上面说到的两种解法要优秀。
这种问题我们可以联想到把每组询问都离线下来,按照r从小到大排序,根据r来更新原数组中数字所在位置,同时在更新的过程中要删去一些已经出现过的重复数字。
代码: (个人感觉写的还是满通俗易懂的)
#include <stdio.h> #include <iostream> #include <cstring> #include <algorithm> #include <cmath> #include <queue> #include <map> #include <stack> #include <sstream> #include <set> #pragma GCC optimize(2) //#define int long long #define mm(i,v) memset(i,v,sizeof i); #define mp(a, b) make_pair(a, b) #define pi acos(-1) #define fi first #define se second //你冷静一点,确认思路再敲!!! using namespace std; typedef long long ll; typedef pair<int, int > PII; priority_queue< PII, vector<PII>, greater<PII> > que; stringstream ssin; // ssin << string while ( ssin >> int) const int N = 1e6 + 5, M = 2e5 + 5, mod = 1e9 + 9, INF = 0x3f3f3f3f; int n, m; int a[N]; int tr[N]; inline int read(){ char c=getchar();int x=0,f=1; while(c<'0'||c>'9'){if(c=='-')f=-1; c=getchar();} while(c>='0'&&c<='9'){x=x*10+c-'0'; c=getchar();} return x*f; } struct node { int l, r, id, ans; }list[N]; bool cmp1(node a, node b) { return a.r < b.r; } bool cmp2(node a, node b) { return a.id < b.id; } int lowbit(int x) { return x & -x; } void add(int x, int y) { for (int i = x; i <= n; i += lowbit(i)) { tr[i] += y; } } int sum(int x) { int ans = 0; for (int i = x; i > 0; i -= lowbit(i)) { ans += tr[i]; } return ans; } int vis[N]; int main() { n = read(); for (int i = 1; i <= n; ++i) a[i] = read(); m = read(); for (int i = 1; i <= m; ++i) { list[i].l = read(); list[i].r = read(); list[i].id = i; } sort(list + 1, list + 1 + m, cmp1); int st = 1; for (int i = 1; i <= m; ++i) { for (int j = st; j <= list[i].r; ++j) { if (vis[a[j]]) add(vis[a[j]], -1); add(j, 1); vis[a[j]] = j; } list[i].ans = sum(list[i].r) - sum(list[i].l - 1); st = list[i].r + 1; } sort(list + 1, list + 1 + m, cmp2); for (int i = 1; i <= m; ++i) { printf("%d ", list[i].ans); } // system("pause"); return 0; }
到此,终于把这题结束了。