思路:用 Splay 树查找前驱和后继
代码:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace std;

typedef long long ll;

// Sentinel key. main() inserts +INF and -INF first so that every real key
// always has both a predecessor and a successor in the tree.
const int INF = 0x3f3f3f3f;
// Node pool size. Slot 0 is the "null" node, so at most maxn - 1 nodes
// (including the two sentinels) can be allocated.
const int maxn = 32767 + 10;

int n;     // number of input values
int cnt;   // number of nodes allocated so far (next free slot is cnt + 1)
int root;  // index of the current root, 0 when the tree is empty

// One splay-tree node. The type is named Node (not `splay`) so it no longer
// shares its name with the splay() function below.
struct Node {
    int ch[2];  // ch[0] = left child, ch[1] = right child (0 = none)
    int size;   // total multiplicity of the subtree rooted here
    int cnt;    // multiplicity of `val` in this node
    int val;    // key stored in this node
    int fa;     // parent index (0 for the root)
} t[maxn];

// Fast signed-integer reader from stdin (not used by main; kept as a helper).
// Returns 0 if EOF is reached before any digit.
int gi() {
    int ans = 0, f = 1;
    int c = getchar();  // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        ans = ans * 10 + (c - '0');
        c = getchar();
    }
    return ans * f;
}

// Prints the keys of the subtree rooted at x in ascending (in-order) order,
// each followed by a space. A key with cnt > 1 is printed once.
void out(int x) {
    if (!x) return;  // empty subtree: do not print the null slot
    out(t[x].ch[0]);
    printf("%d ", t[x].val);
    out(t[x].ch[1]);
}

// Returns 1 if x is its parent's right child, 0 otherwise.
int get(int x) {
    return t[t[x].fa].ch[1] == x;
}

// Recomputes t[x].size from its children and its own multiplicity.
// Relies on t[0].size staying 0.
void up(int x) {
    t[x].size = t[t[x].ch[0]].size + t[t[x].ch[1]].size + t[x].cnt;
}

// Rotates x above its parent while keeping the in-order sequence intact.
void rotate(int x) {
    int fa = t[x].fa, gfa = t[fa].fa;
    int d1 = get(x), d2 = get(fa);
    int mid = t[x].ch[d1 ^ 1];  // subtree that moves from x to fa
    t[fa].ch[d1] = mid;
    if (mid) t[mid].fa = fa;
    // Never write into the null slot 0: stray child links there would
    // corrupt get()/kth() later on.
    if (gfa) t[gfa].ch[d2] = x;
    t[x].fa = gfa;
    t[x].ch[d1 ^ 1] = fa;
    t[fa].fa = x;
    up(fa);
    up(x);
}

// Splays x upward until its parent is `goal`; goal == 0 makes x the new root.
// When x and its parent are on the same side, the parent is rotated first
// (zig-zig); otherwise x is rotated twice (zig-zag).
void splay(int x, int goal) {
    if (!x) return;  // splaying the null node would reset root to 0
    while (t[x].fa != goal) {
        int fa = t[x].fa, gfa = t[fa].fa;
        if (gfa != goal) {
            if (get(x) == get(fa)) rotate(fa);
            else rotate(x);
        }
        rotate(x);
    }
    if (goal == 0) root = x;
}

// Returns the node holding `val`. If val is absent, it returns the last node
// on the search path, which holds val's nearest smaller or larger key.
// Returns 0 on an empty tree. Does not splay.
int find(int val) {
    int node = root;
    while (node && t[node].val != val) {
        int nxt = t[node].ch[t[node].val < val];
        if (!nxt) break;
        node = nxt;
    }
    return node;
}

// Inserts one occurrence of val and splays its node to the root.
void insert(int val) {
    int node = root, fa = 0;
    while (node && t[node].val != val) {
        fa = node;
        node = t[node].ch[t[node].val < val];
    }
    if (node) {
        t[node].cnt++;
        // splay() performs no rotation (and so no up()) when node is already
        // the root, so refresh its size explicitly.
        up(node);
    } else {
        node = ++cnt;
        if (fa) t[fa].ch[t[fa].val < val] = node;
        t[node].ch[0] = t[node].ch[1] = 0;
        t[node].size = t[node].cnt = 1;
        t[node].fa = fa;
        t[node].val = val;
    }
    splay(node, 0);
}

// Returns the INDEX into t[] (not the key) of val's predecessor (kind == 0)
// or successor (kind == 1).
// By default the comparison is inclusive: a node holding val itself counts.
// With strict == true, only keys strictly smaller / larger than val qualify;
// delet() needs the strict form.
// Returns 0 when no such node exists. Splays find(val) to the root as a
// side effect.
int pre(int val, int kind, bool strict = false) {
    if (!root) return 0;
    splay(find(val), 0);
    int node = root;
    bool rootQualifies;
    if (kind == 0) rootQualifies = strict ? t[node].val < val : t[node].val <= val;
    else rootQualifies = strict ? t[node].val > val : t[node].val >= val;
    if (rootQualifies) return node;
    // Otherwise the answer is the extreme key on the near side of the root:
    // the maximum of the left subtree, or the minimum of the right subtree.
    node = t[node].ch[kind];
    while (node && t[node].ch[kind ^ 1]) node = t[node].ch[kind ^ 1];
    return node;
}

// Removes one occurrence of val; does nothing if val is absent.
// The strict predecessor is splayed to the root and the strict successor
// directly beneath it. The successor's left subtree then holds exactly the
// node for val.
void delet(int val) {
    int last = pre(val, 0, true), next = pre(val, 1, true);
    if (!last || !next) return;  // no strict neighbour on one side
    splay(last, 0);
    splay(next, last);
    int del = t[next].ch[0];
    if (!del) return;  // val is not in the tree
    if (t[del].cnt > 1) {
        t[del].cnt--;
        splay(del, 0);  // the rotations recompute sizes along the path
    } else {
        t[next].ch[0] = 0;
        t[del].fa = 0;
        up(next);  // sizes above the removed leaf must be refreshed
        up(last);
    }
}

// Returns the k-th smallest key (1-based), counting multiplicity and every
// key in the tree, including any sentinels. Returns INF if k is out of range.
int kth(int k) {
    if (k <= 0 || t[root].size < k) return INF;
    int node = root;
    while (true) {
        int son = t[node].ch[0];
        if (k <= t[son].size) {
            node = son;
        } else if (k > t[son].size + t[node].cnt) {
            k -= t[son].size + t[node].cnt;
            node = t[node].ch[1];
        } else {
            return t[node].val;
        }
    }
}

// Splays find(val) to the root and returns the total multiplicity of keys
// smaller than the root's key.
// When val is present and the -INF sentinel is in the tree, this equals
// val's 1-based rank among the real keys.
int get_rank(int val) {
    if (!root) return 0;
    splay(find(val), 0);
    return t[t[root].ch[0]].size;
}

// Reads n and then n integers. Prints the first value plus, for each later
// value, the smallest distance to any previously read value.
int main() {
    int a;
    root = cnt = 0;
    ll ans = 0;  // 64-bit so a long input cannot overflow the sum
    insert(INF), insert(-INF);
    if (scanf("%d", &n) != 1 || n <= 0 || scanf("%d", &a) != 1) {
        printf("0\n");
        return 0;
    }
    insert(a);
    ans += a;
    for (int i = 1; i <= n - 1; i++) {
        if (scanf("%d", &a) != 1) break;  // input ended early
        ll p1 = t[pre(a, 0)].val;  // largest key <= a (at worst -INF)
        ll p2 = t[pre(a, 1)].val;  // smallest key >= a (at worst +INF)
        ans += min(a - p1, p2 - a);
        insert(a);
    }
    printf("%lld\n", ans);
    return 0;
}