本文移植自个人的原博客(原文

题目链接

COGS 1829

题目描述

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  1. 插入$x$数;

  2. 删除$x$数(若有多个相同的数,因只删除一个);

  3. 查询$x$数的排名(若有多个相同的数,因输出最小的排名);

  4. 查询排名为$x$的数;

  5. 求$x$的前驱(前驱定义为小于$x$,且最大的数);

  6. 求$x$的后继(后继定义为大于$x$,且最小的数)。

输入输出格式

输入格式

第一行为$n$,表示操作的个数,下面$n$行每行有两个数$opt$和$x$,$opt$表示操作的序号$(1<=opt<=6)$,$x$含义如上所示。

$n$的数据范围:$n<=100000$

每个数的数据范围:$[−1e7,1e7]$(BZOJ3224:$[−2e9,2e9]$)

输出格式

对于操作$3,4,5,6$每行输出一个数,表示对应答案。

样例

输入样例

10
1 106465
4 1
1 317721
1 460929
1 644985
1 84185
1 89851
6 81968
1 492737
5 493598

输出样例

106465
84185
492737

题解

各种平衡树的模板题(包括SPLAY,TREAP,SBT等等。然而只会SPLAY)

V1

#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
using namespace std;
const int inf = 0x7fffffff;
const int N = 1e5+1;
struct splay{int data,ls,rs,fa,size;}a[N];
int q,root=0,tot=0;
void pushup(int x){a[x].size=a[a[x].ls].size+a[a[x].rs].size+1;}
void zig(int x)
{
int y=a[x].fa;
int z=a[y].fa;
a[y].fa=x;a[x].fa=z;
a[y].ls=a[x].rs;a[a[x].rs].fa=y;a[x].rs=y;
if(y==a[z].ls) a[z].ls=x;
else a[z].rs=x;
pushup(y);
}
void zag(int x)
{
int y=a[x].fa;
int z=a[y].fa;
a[y].fa=x,a[x].fa=z;
a[y].rs=a[x].ls;a[a[x].ls].fa=y;a[x].ls=y;
if(y==a[z].ls) a[z].ls=x;
else a[z].rs=x;
pushup(y);
}
void splay(int x,int s)
{
while (a[x].fa!=s)
{
int y=a[x].fa;
int z=a[y].fa;
if(z==s)
{
if(x==a[y].ls) zig(x);
else zag(x);
break;
}
if(y==a[z].ls)
{
if(x==a[y].ls) zig(y),zig(x);
else zag(x),zig(x);
}
else
{
if(x==a[y].rs) zag(y),zag(x);
else zig(x),zag(x);
}
}
pushup(x);
if (s==0) root=x;
}
int Search(int w)
{
int p,x=root;
while (x)
{
p=x;
if (a[x].data>w) x=a[x].ls;
else x=a[x].rs;
}
return p;
}
void newnode(int &x,int fa,int key)
{
x=++tot;
a[x].ls=a[x].rs=0;
a[x].fa=fa;
a[x].data=key;
}
void insert(int w)
{
if (root==0)
{
newnode(root,0,w);
return;
}
int i=Search(w);
if (w<a[i].data) newnode(a[i].ls,i,w);
else newnode(a[i].rs,i,w);
splay(tot,0);
}
int get(int w)
{
int x=root,ans=tot+1;
while(x)
{
if(a[x].data>w){x=a[x].ls;continue;}
if(a[x].data<w){x=a[x].rs;continue;}
if(a[x].data==w)
{
ans=x;
x=a[x].ls;
}
}
if(ans==tot+1) return -1;
return ans;
}
int getmax(int x){while(a[x].rs)x=a[x].rs;return x;}
int getmin(int x){while (a[x].ls)x=a[x].ls;return x;}
int getpre1(int x){return getmax(a[root].ls);}
int getne1(int x){return getmin(a[root].rs);}
void delet(int w)
{
int x=get(w);
splay(x,0);
int pp=getpre1(x),nn=getne1(x);
splay(pp,0);
splay(nn,root);
int y=a[x].fa;
a[x].fa=0;
if(x==a[y].ls) a[y].ls=0;
else a[x].ls=0;
pushup(y);pushup(root);
}
int find(int w)
{
int x=get(w);
splay(x,0);
return a[a[x].ls].size;
}
int findkth(int x,int k)
{
int s=a[a[x].ls].size;
if (k==s+1) return a[x].data;
if (s>=k) return findkth(a[x].ls,k);
else return findkth(a[x].rs,k-s-1);
}
int getpre(int w)
{
int y=get(w);
insert(w);
if(y!=-1)splay(y,0);
int ans=getmax(a[root].ls);
delet(w);
return a[ans].data;
}
int getne(int w)
{
insert(w);
int ans=getmin(a[root].rs);
delet(w);
return a[ans].data;
}
int main()
{
root=tot=0;
insert(-inf);insert(inf);
scanf("%d",&q);
while(q--)
{
int x,k;
scanf("%d%d",&x,&k);
if(x==1) insert(k);
else if(x==2) delet(k);
else if(x==3) printf("%d\n",find(k));
else if(x==4) printf("%d\n",findkth(root,k+1));
else if(x==5) printf("%d\n",getpre(k));
else if(x==6) printf("%d\n",getne(k));
}
return 0;
}

V2

#include<cstdio>
#include<iostream>
const int N=1e5+5;
int tot,root,size[N],num[N],key[N],fa[N],son[N][2];
void pushup(int x)
{
size[x]=size[son[x][0]]+size[son[x][1]]+num[x];
}
void zg(int x)
{
//push_down(fa[x]);push_down(x);
int y=fa[x],z=fa[y],t=(son[y][0]==x);
if(z) son[z][son[z][1]==y]=x;fa[x]=z;
son[y][!t]=son[x][t];fa[son[y][!t]]=y;
son[x][t]=y;fa[y]=x;
pushup(y);pushup(x);
}
void splay(int x,int f)
{
while(fa[x]!=f)
{
int y=fa[x],z=fa[y];
if(z!=f)
{
if(son[z][0]==y^son[y][0]==x) zg(x);
else zg(y);
}
zg(x);
}
if(!f) root=x;
}
void insert(int &x,int v,int f)
{
if(!x)
{
x=++tot;
son[x][0]=son[x][1]=0;
size[x]=num[x]=1;
key[x]=v;fa[x]=f;
splay(x,0);
return;
}
if(v==key[x])
{
num[x]++;size[x]++;
splay(x,0);
return;
}
insert(son[x][v>key[x]],v,x);
pushup(x);
}
int get(int v)
{
int x=root;
while(x&&v!=key[x]) x=son[x][v>key[x]];
return x;
}
void delet(int x)
{
x=get(x);if(!x) return;
splay(x,0);
if(num[x]>1) {num[x]--;size[x]--;return;}
if(!son[x][0]||!son[x][1]) root=son[x][0]+son[x][1];
else
{
int y=son[x][1];while(son[y][0]) y=son[y][0];
splay(y,x);
son[y][0]=son[x][0];fa[son[y][0]]=y;
root=y;
}
fa[root]=0;
pushup(root);
}
int getrank(int v)
{
insert(root,v,0);
int ans=size[son[root][0]]+1;
delet(v);
return ans;
}
int kth(int x)
{
int y=root;
while(x<=size[son[y][0]]||x>size[son[y][0]]+num[y])
{
if(x<=size[son[y][0]]) y=son[y][0];
else x-=size[son[y][0]]+num[y],y=son[y][1];
}
return key[y];
}
int pre(int v)
{
insert(root,v,0);
int x=son[root][0];while(son[x][1]) x=son[x][1];
delet(v);
return key[x];
}
int ne(int v)
{
insert(root,v,0);
int x=son[root][1];while(son[x][0]) x=son[x][0];
delet(v);
return key[x];
}
int main()
{
int n;
scanf("%d",&n);
for(int i=1,x,y;i<=n;++i)
{
if(x==1) insert(root,y,0);
if(x==2) delet(y);
if(x==3) printf("%d\n",getrank(y));
if(x==4) printf("%d\n",kth(y));
if(x==5) printf("%d\n",pre(y));
if(x==6) printf("%d\n",ne(y));
}
}