题目
给定两个序列a和b,每个序列中可能含有重复的数字。
一个配对(i,j)是一个好配对当从第一个序列中选出一个数ai,再从第二个序列中选出一个数bj且满足ai>bj。
给出两个序列,问存在多少个好配对。
题目链接: 好配对
有题目要求,知道题目的数据量比较大:a和b中分别最多有10^5种不同数字,每个数字最多有10^4个。因此,要求算法有O(nlogn)的时间复杂度。
一开始使用了两个map,map1为序列a中的数字以及对应的个数构成的数对;map2为对于序列a中的数字x,序列b中小于x的数字的个数。这样在第一次输入序列a,时候创建map1,以及将map2中的value均设置为0;在输入序列b时,若当前读取数值为x,个数为y,从map1的末尾向前查找直到map1中当前的key值小于等于x,在经过的那些(key, value)对中,value均加上y,表示在序列b中小于key值的数字个数增加y个。
最后,从头到尾遍历一遍 map1和map2, 求和map1[key]*map2[key]就得到最终结果。
结果华丽的超时了: 在对b序列中的每个数字,从末尾到首部遍历map1,构成了O(n^2)的复杂度了。。
超时之后,朝着 O(nlogn)的复杂度方向努力:使用平衡二叉树节点维持数值x,节点中等于x的个数,节点所代表的子树的总数字的个数。在读取序列a的时候,构建这棵平衡二叉树,复杂度为O(nlogn);在读取序列b的时候,对b中的每个数字x,从该平衡二叉树上获得大于x的数字的总个数sum(时间复杂度O(logn),最终结果加上 y*sum.
总时间复杂度为 O(nlogn)
平衡二叉树使用treap来实现。
实现
#include<stdio.h> #include<string.h> #include<iostream> #include<string> #include<set> #include<map> #include<vector> #include<queue> #include<stack> #include<unordered_map> #include<unordered_set> #include<algorithm> using namespace std; struct Node{ int val; int count; int sum; int priority; Node* childs[2]; Node(){ val = count = sum = 0; childs[0] = childs[1] = NULL; priority = rand(); } void Update(){ sum = count; if (childs[0]) sum += childs[0]->sum; if (childs[1]) sum += childs[1]->sum; } }; struct Treap{ Node* root; Treap(){ root = NULL; } void Delete(Node*& node){ if (!node) return; if (node->childs[0]) Delete(node->childs[0]); if (node->childs[1]) Delete(node->childs[1]); delete node; node = NULL; //注意赋值为NULL,否则在反复使用treap时出错 } void Rotate(Node*& node, bool dir){ Node* ch = node->childs[dir]; node->childs[dir] = ch->childs[!dir]; ch->childs[!dir] = node; node->Update(); //注意更新,因为此时修改了树的结构 node = ch; } void Insert(Node*& node, int val, int count){ if (node == NULL){ node = new Node(); node->val = val; node->sum = node->count = count; return; } if (node->val == val){ node->count += count; node->sum += count; return; } bool ch = node->val < val; Insert(node->childs[ch], val, count); if (node->childs[ch]->priority > node->priority){ Rotate(node, ch); } node->Update(); //更新,此时修改了树的结构 } int Bigger(Node* node, int val){ if (!node) return 0; if (node->val == val) return (node->childs[1]? node->childs[1]->sum:0); else if (node->val < val) return Bigger(node->childs[1], val); else{ return (node->childs[1] ? node->childs[1]->sum : 0) + node->count + Bigger(node->childs[0], val); } } }; int main(){ int T, n, m, x, y; scanf("%d", &T); Treap treap; while (T--){ scanf("%d %d", &n, &m); treap.Delete(treap.root); for (int i = 0; i < n; i++){ scanf("%d %d", &x, &y); treap.Insert(treap.root, x, y); } long long result = 0; for (int i = 0; i < m; i++){ scanf("%d %d", &x, &y); long long int bigger = treap.Bigger(treap.root, x); result += y*bigger; } printf("%lld ", result); } return 0; }