Abstract
HDU 4343 Interval query
离线查询 倍增祖先
10级的就别看了,11级的读一下。
Body
Source
http://acm.hdu.edu.cn/showproblem.php?pid=4343
Description
给定N(N<=100000)个区间(左闭右开)和M(M<=100000)个询问[l, r],问所有满足[s,t)包含于于[l, r]的区间中最多能选出多少个,使得他们两两不相交。
Solution
首先是一个贪心的想法,如果存在两个区间[s1,t1)和[s2,t2),且[s1,t1)包含于[s2,t2),那么[s2,t2)是可以扔掉的。很显然地,如果[s2,t2)包含于某个解中,那么把它换成[s1,t1)肯定个数不变。
所以就可以把包含了其它区间的那些区间去掉。具体做法是天王想的,特别抽,详见代码。
其实这步做不做无所谓,因为最坏情况下不会扔掉任何一个区间。但事实证明做了这一步就可以O(N*M)暴力水过……
剩下来的区间按左端点升序排序,那么右端点一定也是升序的。
然后求方案其实还是一个贪心的想法,对于一个询问[l, r],设i=max{i|t[i]<=r},如果s[i]>=l,那么就选上i,然后设prev[i]=max{j|j<i且t[j]<=s[i]},也就是prev[i]为i之前最右边的和i不相交的区间,如果s[prev[i]]>=l,那么就选上prev[i]。就这样不停贪心地向前找,直到找到某个j=max{j|s[j]<l},那么j就不能选了。这么做的正确性也很显然,就不证了。
于是我们发现,这些线段就形成了一个树型结构,除了第一个线段,每个线段都有其前继。线段的排序是这颗有向树的拓扑排序的一种。
如果将询问排序按右端点排序后,容易用均摊线性时间找到对于询问[l, r]的终止节点i=max{i|t[i]<=r}。那么接下来就是顺着前继向前找,直到发现j=max{j|s[j]<l}。
暴力找的复杂度是O(N),有M个询问,也就是总共是O(N*M)。由于数据比较水,在第一个贪心思路把没用的区间去掉后可以750ms通过(RunID=6472102,似乎比不加这个优化然后正经做的都要快……)。
稍微思考一下。顺着前继向前找,一直找到最左边,这样找到的节点形成了一条链。很显然,链上的左端点是单调的,这给了我们二分的启示。
但是普通的二分需要存储每一条链,这在时间和空间上都是不允许的。
将ans进行二进制表示,令ans=2^k*b[k]+2^(k-1)*b[k-1]+...+2^0*b[0],b[t]=0或1, t=k,k-1,...,0,可以发现,二分的过程事实上就是根据当前的ans判断出每一个b[t], t=k,k-1,...,0是1还是0的过程。
这样就可以利用类似于最近公共祖先(LCA)的倍增法的思路,记录i的2^j祖先,也就是i节点向前找2^j是什么节点。然后就可以根据每次向前找2^j, j=k,k-1,...,0后,节点的左端点是否还满足s[j]>=l来判断ans的二进制表示中第j位是1还是0了。
Code
#include <cstdio> #include <cstring> #include <vector> #include <algorithm> using namespace std; struct ss { ss() {} ss(int a, int b): s(a), t(b) {} int s, t; int p, q; bool vis; bool operator<(const ss &rhs)const { return t<rhs.t; } void write() { printf("%d %d\n", s, t); } }; bool cmp(const ss &u, const ss &v) { if (u.p==v.p) return u.q>v.q; return u.p<v.p; } struct sq { int id; int s, t; bool operator<(const sq &rhs) const { return t<rhs.t; } }q[100010]; const int INF = 1000000001; int N, M; ss hmr[100010], mdk[100010]; int ans[100010]; int prev[100010][22]; int main() { int i, j, k; while (~scanf("%d%d", &N, &M)) { for (i = 0; i < N; ++i) { scanf("%d%d", &hmr[i].s, &hmr[i].t); hmr[i].p = hmr[i].t; hmr[i].q = INF+hmr[i].s; hmr[i].vis = 0; } sort(hmr, hmr+N, cmp); for (i = 0; i < N; ) { j = i+1; while (j<N && hmr[j].q <= hmr[i].q) hmr[j++].vis = 1; i = j; } j = 0; for (i = 0; i < N; ++i) if (!hmr[i].vis) mdk[++j] = hmr[i]; N = j; sort(mdk+1, mdk+N+1); mdk[0] = ss(-2, -1); for (i=0,j=1; j<N; ++i,j<<=1); int bmax = i; k = 1; for (i = 1; i <= N; ++i) { while (mdk[k].t<=mdk[i].s) k++; prev[i][0] = k-1; } prev[0][0] = 0; for (j = 1; j <= bmax; ++j) { prev[0][j] = 0; for (i = 1; i <= N; ++i) { prev[i][j] = prev[prev[i][j-1]][j-1]; } } for (i = 0; i < M; ++i) { scanf("%d%d", &q[i].s, &q[i].t); q[i].id = i; } sort(q, q+M); k = 0; int b; for (i = 0; i < M; ++i) { while (k<=N && mdk[k].t <= q[i].t) k++; int &res = ans[q[i].id], now = k-1; res = 0; for (j=bmax; j>=0; j--) { if (mdk[prev[now][j]].s>=q[i].s) { res |= (1<<j); now = prev[now][j]; } } if (mdk[k-1].s >= q[i].s) res++; } for (i = 0; i < M; ++i) printf("%d\n", ans[i]); } return 0; }