Splay
Splay是一种二叉排序树,空间效率:O(n),时间效率:O(logn)内完成插入、查找、删除操作,优点:每次查询会调整树的结构,使被查询频率高的条目更靠近树根。
Tree Rotation
树的旋转是splay的基础,对于二叉查找树来说,树的旋转不破坏查找树的结构。
Splaying
Splaying是Splay Tree中的基本操作,为了让被查询的条目更接近树根,Splay Tree使用了树的旋转操作,同时保证二叉排序树的性质不变。
Splaying的操作受以下三种因素影响:
- 节点x是父节点p的左孩子还是右孩子
- 节点p是不是根节点,如果不是
- 节点p是父节点g的左孩子还是右孩子
同时有三种基本操作:
Zig Step
当p为根节点时,进行zip step操作。
当x是p的左孩子时,对x右旋;
当x是p的右孩子时,对x左旋。
Zig-Zig Step
当p不是根节点,且x和p同为左孩子或右孩子时进行Zig-Zig操作。
当x和p同为左孩子时,依次将p和x右旋;
当x和p同为右孩子时,依次将p和x左旋。
Zig-Zag Step
当p不是根节点,且x和p不同为左孩子或右孩子时,进行Zig-Zag操作。
当p为左孩子,x为右孩子时,将x左旋后再右旋。
当p为右孩子,x为左孩子时,将x右旋后再左旋。
下面是splay的伪代码:
P(X) : 获得X的父节点,G(X) : 获得X的祖父节点(=P(P(X)))。
Function Buttom-up-splay:
Do
If X 是 P(X) 的左子结点 Then
If G(X) 为空 Then
X 绕 P(X)右旋
Else If P(X)是G(X)的左子结点
P(X) 绕G(X)右旋
X 绕P(X)右旋
Else
X绕P(X)右旋
X绕P(X)左旋 (P(X)和上面一句的不同,是原来的G(X))
Endif
Else If X 是 P(X) 的右子结点 Then
If G(X) 为空 Then
X 绕 P(X)左旋
Else If P(X)是G(X)的右子结点
P(X) 绕G(X)左旋
X 绕P(X)左旋
Else
X绕P(X)左旋
X绕P(X)右旋 (P(X)和上面一句的不同,是原来的G(X))
Endif
Endif
While (P(X) != NULL)
EndFunction
仔细分析zig-zag,可以发现,其实zig-zag就是两次zig。因此上面的代码可以简化:
Function Buttom-up-splay:
Do
If X 是 P(X) 的左子结点 Then
If P(X)是G(X)的左子结点
P(X) 绕G(X)右旋
Endif
X 绕P(X)右旋
Else If X 是 P(X) 的右子结点 Then
If P(X)是G(X)的右子结点
P(X) 绕G(X)左旋
Endif
X 绕P(X)左旋
Endif
While (P(X) != NULL)
EndFunction
Function Top-Down-Splay
Do
If X 小于 T Then
If X 等于 T 的左子结点 Then
右连接
ElseIf X 小于 T 的左子结点 Then
T的左子节点绕T右旋
右连接
Else X大于 T 的左子结点 Then
右连接
左连接
EndIf
ElseIf X大于 T Then
IF X 等于 T 的右子结点 Then
左连接
ElseIf X 大于 T 的右子结点 Then
T的右子节点绕T左旋
左连接
Else X小于 T 的右子结点‘ Then
左连接
右连接
EndIf
EndIf
While !(找到 X或遇到空节点)
组合左中右树
EndFunction
模板代码如下:
#include <iostream>
using namespace std;
#define MAXN 100010
struct Node{
int key, sz, cnt;
Node *ch[2], *pnt;//左右儿子和父亲
Node(){}
Node(int x, int y, int z){
key = x, sz = y, cnt = z;
}
void rs(){
sz = ch[0]->sz + ch[1]->sz + cnt;
}
}nil(0, 0, 0), *NIL = &nil;
struct Splay{//伸展树结构体类型
Node *root;
int ncnt;//计算key值不同的结点数,注意已经去重了
Node nod[MAXN];
void init(){// 首先要初始化
root = NIL;
ncnt = 0;
}
void rotate(Node *x, bool d){//旋转操作,d为true表示右旋
Node *y = x->pnt;
y->ch[!d] = x->ch[d];
if (x->ch[d] != NIL)
x->ch[d]->pnt = y;
x->pnt = y->pnt;
if (y->pnt != NIL){
if (y == y->pnt->ch[d])
y->pnt->ch[d] = x;
else
y->pnt->ch[!d] = x;
}
x->ch[d] = y;
y->pnt = x;
y->rs();
x->rs();
}
void splay(Node *x, Node *target){//将x伸展到target的儿子位置处
Node *y;
while (x->pnt != target){
y = x->pnt;
if (x == y->ch[0]){
if (y->pnt != target && y == y->pnt->ch[0])
rotate(y, true);
rotate(x, true);
}
else{
if (y->pnt != target && y == y->pnt->ch[1])
rotate(y, false);
rotate(x, false);
}
}
if (target == NIL)
root = x;
}
/************************以上一般不用修改************************/
void insert(int key){//插入一个值
if (root == NIL){
ncnt = 0;
root = &nod[++ncnt];
root->ch[0] = root->ch[1] = root->pnt = NIL;
root->key = key;
root->sz = root->cnt = 1;
return;
}
Node *x = root, *y;
while (1){
x->sz++;
if (key == x->key){
x->cnt++;
x->rs();
y = x;
break;
}
else if (key < x->key){
if (x->ch[0] != NIL)
x = x->ch[0];
else{
x->ch[0] = &nod[++ncnt];
y = x->ch[0];
y->key = key;
y->sz = y->cnt = 1;
y->ch[0] = y->ch[1] = NIL;
y->pnt = x;
break;
}
}
else{
if (x->ch[1] != NIL)
x = x->ch[1];
else{
x->ch[1] = &nod[++ncnt];
y = x->ch[1];
y->key = key;
y->sz = y->cnt = 1;
y->ch[0] = y->ch[1] = NIL;
y->pnt = x;
break;
}
}
}
splay(y, NIL);
}
Node* search(int key){//查找一个值,返回指针
if (root == NIL)
return NIL;
Node *x = root, *y = NIL;
while (1){
if (key == x->key){
y = x;
break;
}
else if (key > x->key){
if (x->ch[1] != NIL)
x = x->ch[1];
else
break;
}
else{
if (x->ch[0] != NIL)
x = x->ch[0];
else
break;
}
}
splay(x, NIL);
return y;
}
Node* searchmin(Node *x){//查找最小值,返回指针
Node *y = x->pnt;
while (x->ch[0] != NIL){//遍历到最左的儿子就是最小值
x = x->ch[0];
}
splay(x, y);
return x;
}
void del(int key){//删除一个值
if (root == NIL)
return;
Node *x = search(key), *y;
if (x == NIL)
return;
if (x->cnt > 1){
x->cnt--;
x->rs();
return;
}
else if (x->ch[0] == NIL && x->ch[1] == NIL){
init();
return;
}
else if (x->ch[0] == NIL){
root = x->ch[1];
x->ch[1]->pnt = NIL;
return;
}
else if (x->ch[1] == NIL){
root = x->ch[0];
x->ch[0]->pnt = NIL;
return;
}
y = searchmin(x->ch[1]);
y->pnt = NIL;
y->ch[0] = x->ch[0];
x->ch[0]->pnt = y;
y->rs();
root = y;
}
int rank(int key){//求结点高度
Node *x = search(key);
if (x == NIL)
return 0;
return x->ch[0]->sz + 1/* or x->cnt*/;
}
Node* findk(int kth){//查找第k小的值
if (root == NIL || kth > root->sz)
return NIL;
Node *x = root;
while (1){
if (x->ch[0]->sz +1 <= kth && kth <= x->ch[0]->sz + x->cnt)
break;
else if (kth <= x->ch[0]->sz)
x = x->ch[0];
else{
kth -= x->ch[0]->sz + x->cnt;
x = x->ch[1];
}
}
splay(x, NIL);
return x;
}
}sp;
int main(){
sp.init();
sp.insert(10);
sp.insert(2);
sp.insert(2);
sp.insert(2);
sp.insert(3);
sp.insert(3);
sp.insert(10);
for (int i = 1; i <= 7; i++)
cout << sp.findk(i)->key << endl;
cout<<sp.searchmin(sp.root)->key<<endl;
sp.del(2);
sp.del(3);
sp.del(1);
return 0;
}
应用
Splay Tree可以方便的解决一些区间问题,根据不同形状二叉树先序遍历结果不变的特性,可以将区间按顺序建二叉查找树。
每次自下而上的一套splay都可以将x移动到根节点的位置,利用这个特性,可以方便的利用Lazy的思想进行区间操作。
对于每个节点记录size,代表子树中节点的数目,这样就可以很方便地查找区间中的第k小或第k大元素。
对于一段要处理的区间[x, y],首先splay x-1到root,再splay y+1到root的右孩子,这时root的右孩子的左孩子对应子树就是整个区间。
这样,大部分区间问题都可以很方便的解决,操作同样也适用于一个或多个条目的添加或删除,和区间的移动。
题意:求区间第k小数,可以用划分树来做,这里区间不会重叠,所以不可能有首首相同或尾尾相同的情况,读入所有区间,按照右端由小到大排序。然后通过维护splay进行第k小元素的查询操作。
代码如下:
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int maxn=100100;
int n,m,sorted[maxn],tree[23][maxn],toleft[23][maxn];
void Build(int l,int r,int deep)
{
if(l==r)
return;
int mid=(l+r)>>1;
int same=mid-l+1;
for(int i=l;i<=r;i++)
if(tree[deep][i]<sorted[mid])
same--;
int ls=l,rs=mid+1;
for(int i=l;i<=r;i++)
{
toleft[deep][i]=toleft[deep][i-1];
if(tree[deep][i]<sorted[mid])
{
tree[deep+1][ls++]=tree[deep][i];
toleft[deep][i]++;
}
else if(tree[deep][i]==sorted[mid])
{
if(same)
{
tree[deep+1][ls++]=tree[deep][i];
toleft[deep][i]++;
same--;
}
else
tree[deep+1][rs++]=tree[deep][i];
}
else
tree[deep+1][rs++]=tree[deep][i];
}
Build(l,mid,deep+1);
Build(mid+1,r,deep+1);
}
int Query(int l,int r,int L,int R,int deep,int k)
{
if(l==r)
return tree[deep][l];
int mid=(L+R)>>1;
int x=toleft[deep][l-1]-toleft[deep][L-1];
int y=toleft[deep][r]-toleft[deep][L-1];
int ry=r-L-y;
int rx=l-L-x;
int cnt=y-x;
if(cnt>=k)
return Query(L+x,L+y-1,L,mid,deep+1,k);
else
return Query(mid+rx+1,mid+ry+1,mid+1,R,deep+1,k-cnt);
}
int main()
{
while(scanf("%d%d",&n,&m)!=EOF)
{
for(int i=1;i<=n;i++)
{
scanf("%d",&sorted[i]);
tree[0][i]=sorted[i];
}
sort(sorted+1,sorted+1+n);
Build(1,n,0);
while(m--)
{
int a,b,k;
scanf("%d%d%d",&a,&b,&k);
printf("%d\n",Query(a,b,1,n,0,k));
}
}
return 0;
}
题意:给出一个n个数的数列a,对于第i个元素ai定义fi=min(abs(ai-aj)),(1<=j<i),其中f1=a1。输出sum(fi) (1<=i<=n)
splay模板题,代码如下:
#include<iostream>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<cstdio>
#include<set>
#include<map>
#include<vector>
#include<cstring>
#include<stack>
#include<cmath>
#include<queue>
using namespace std;
#define CL(x,v); memset(x,v,sizeof(x));
#define INF 0x3f3f3f3f
#define LL long long
#define REP(i,r,n) for(int i=r;i<=n;i++)
#define RREP(i,n,r) for(int i=n;i>=r;i--)
const int MAXN=200010;
const int mod=1000000;
struct SplayTree {
int sz[MAXN];
int ch[MAXN][2];
int pre[MAXN];
int rt,top;
inline void up(int x){
sz[x] = cnt[x] + sz[ ch[x][0] ] + sz[ ch[x][1] ];
}
inline void Rotate(int x,int f){
int y=pre[x];
ch[y][!f] = ch[x][f];
pre[ ch[x][f] ] = y;
pre[x] = pre[y];
if(pre[x]) ch[ pre[y] ][ ch[pre[y]][1] == y ] =x;
ch[x][f] = y;
pre[y] = x;
up(y);
}
inline void Splay(int x,int goal){//将x旋转到goal的下面
while(pre[x] != goal){
if(pre[pre[x]] == goal) Rotate(x , ch[pre[x]][0] == x);
else {
int y=pre[x],z=pre[y];
int f = (ch[z][0]==y);
if(ch[y][f] == x) Rotate(x,!f),Rotate(x,f);
else Rotate(y,f),Rotate(x,f);
}
}
up(x);
if(goal==0) rt=x;
}
inline void RTO(int k,int goal){//将第k位数旋转到goal的下面
int x=rt;
while(sz[ ch[x][0] ] != k-1) {
if(k < sz[ ch[x][0] ]+1) x=ch[x][0];
else {
k-=(sz[ ch[x][0] ]+1);
x = ch[x][1];
}
}
Splay(x,goal);
}
inline void vist(int x){
if(x){
printf("结点%2d : 左儿子 %2d 右儿子 %2d %2d sz=%d\n",x,ch[x][0],ch[x][1],val[x],sz[x]);
vist(ch[x][0]);
vist(ch[x][1]);
}
}
inline void Newnode(int &x,int c){
x=++top;
ch[x][0] = ch[x][1] = pre[x] = 0;
sz[x]=1; cnt[x]=1;
val[x] = c;
}
inline void init(){
ch[0][0]=ch[0][1]=pre[0]=sz[0]=0;
rt=top=0; cnt[0]=0;
Newnode(rt,-INF);
Newnode(ch[rt][1],INF);
pre[top]=rt;
sz[rt]=2;
}
inline void Insert(int &x,int key,int f){
if(!x) {
Newnode(x,key);
pre[x]=f;
Splay(x,0);//效率的保证
return ;
}
if(key==val[x]){
cnt[x]++;
sz[x]++;
Splay(x,0);//不加会超时,囧啊
return ;
}else if(key<val[x]) {
Insert(ch[x][0],key,x);
} else {
Insert(ch[x][1],key,x);
}
up(x);
}
void findpre(int x,int key,int &ans){
if(!x) return ;
if(val[x] <= key){
ans=val[x];
findpre(ch[x][1],key,ans);
} else
findpre(ch[x][0],key,ans);
}
void findsucc(int x,int key,int &ans){
if(!x) return ;
if(val[x]>=key) {
ans=val[x];
findsucc(ch[x][0],key,ans);
} else
findsucc(ch[x][1],key,ans);
}
int cnt[MAXN];
int val[MAXN];
}spt;
int main()
{
int n;
scanf("%d",&n);
spt.init();
int ans=0;
int a;
scanf("%d",&a);
spt.Insert(spt.rt,a,0);
ans=a;
n--;
while(n--)
{
a=0;//不知道为什么这里要赋值为0,不赋值就wa!!
scanf("%d",&a);
int x,y;
spt.findpre(spt.rt,a,x);
spt.findsucc(spt.rt,a,y);
if(abs(a-x)<=abs(a-y))
{
ans+=abs(a-x);
}
else
{
ans+=abs(a-y);
}
spt.Insert(spt.rt,a,0);
}
printf("%d\n",ans);
return 0;
}