伸展树(splay)

@sixing  October 26, 2019

功能

  • 序列pos位置后面插入一个数
  • 删除pos位置的数
  • 区间[a,b]中的数都加上value
  • 区间[a,b]中的数翻转
  • 区间[a,b]中的数向后循环移t位
  • 取[a,b]中的最小值
  • 求[a,b]的区间和

模板

#include <iostream>
#include <stack>
#include <vector>
#include <cstdio>
using namespace std;
typedef long long ll;
// N:   number of slots in the static node pool `mem` (also the size of the
//      input buffer declared next to main).
// inf: value stored in the two boundary sentinel nodes.
const ll N = 1e6 + 5, inf = 0x3f3f3f3f;
typedef struct splaynode* node;

// One node of the sequence splay tree. Besides its own element it caches
// aggregates over its whole subtree and carries two deferred-operation tags.
struct splaynode {
    node pre, ch[2];           // parent link, then left (0) / right (1) child
    ll value, lazy, min, sum;  // element, pending range-add, subtree minimum, subtree sum
    ll size, rev;              // subtree node count, pending-reversal flag

    // Resets this node to a detached one-element subtree holding _value,
    // with no pending tags.
    void init(ll _value) {
        value = _value;
        min = _value;
        sum = _value;
        size = 1;
        lazy = 0;
        rev = 0;
        pre = NULL;
        ch[0] = NULL;
        ch[1] = NULL;
    }
} mem[N];
// Index of the next never-used slot in `mem`.
ll memtop;
 
// Recycled node slots: erase() pushes the node it removes here and insert()
// pops from here before taking a fresh slot from the pool.
stack<node> S;
// Current root of the splay tree; rotate() reassigns it whenever the old root
// is rotated below another node.
node root;
 
// Returns the cached node count of the subtree rooted at x; a null pointer
// counts as an empty subtree.
inline ll getsize(node &x) {
    if (x == NULL) return 0;
    return x->size;
}
 
// Resolves the deferred tags stored on x and forwards them to its children.
// Does nothing for a null pointer.
//
//  - A non-zero `lazy` is added to x's own value; each existing child gets the
//    same amount added to its `lazy` and `min`, and `sum` grows by the amount
//    times the child's subtree size.
//  - A set `rev` flag swaps x's two children and toggles `rev` on each of them.
void pushdown(node &x) {
    if (x == NULL) return;

    const ll delta = x->lazy;
    if (delta != 0) {
        x->value += delta;
        for (int side = 0; side < 2; ++side) {
            node child = x->ch[side];
            if (child == NULL) continue;
            child->lazy += delta;
            child->min += delta;
            child->sum += delta * getsize(child);
        }
        x->lazy = 0;
    }

    if (x->rev != 0) {
        node formerLeft = x->ch[0];
        x->ch[0] = x->ch[1];
        x->ch[1] = formerLeft;
        x->rev = 0;
        for (int side = 0; side < 2; ++side) {
            if (x->ch[side] != NULL) x->ch[side]->rev ^= 1;
        }
    }
}
 
// Recomputes x's cached size, min and sum from x's own value and the cached
// aggregates of its children. Does nothing for a null pointer.
// The result is built from x->value as currently stored, so callers in this
// file run pushdown(x) first.
void update(node &x) {
    if (x == NULL) return;

    ll count = 1;
    ll lowest = x->value;
    ll total = x->value;
    for (int side = 0; side < 2; ++side) {
        node child = x->ch[side];
        if (child == NULL) continue;
        total += child->sum;
        if (child->min < lowest) lowest = child->min;
        count += child->size;
    }

    x->size = count;
    x->min = lowest;
    x->sum = total;
}
 
// Rotates x above its parent. `d` is the side on which the old parent ends up
// under x (1: parent becomes x's right child, 0: parent becomes x's left
// child); the code assumes x currently sits in the parent's opposite slot.
// Pending tags on the parent, on x and on the subtree that changes hands are
// pushed down first. Only the demoted parent's aggregates are refreshed here;
// x's own aggregates are left for the caller to update.
void rotate(node &x, ll d) {
    node parent = x->pre;
    pushdown(parent);
    pushdown(x);
    node moved = x->ch[d];          // subtree handed over from x to the old parent
    pushdown(moved);

    parent->ch[!d] = moved;
    if (moved != NULL) {
        moved->pre = parent;
    }

    node grand = parent->pre;
    x->pre = grand;
    if (grand != NULL) {
        if (grand->ch[0] == parent) {
            grand->ch[0] = x;
        } else {
            grand->ch[1] = x;
        }
    }

    x->ch[d] = parent;
    parent->pre = x;
    update(parent);

    if (parent == root) {
        root = x;
    }
}
 
// Rotates src upward until it occupies the tree position denoted by dst.
//
// dst is taken by reference. The callers in this file pass `root` or
// `root->ch[1]`, i.e. a pointer that rotate() itself rewrites when src reaches
// that position, so the value seen through dst can change while the loop runs
// and `src != dst` becomes false once src has taken the slot.
// src's aggregates are refreshed with update() after every step and once more
// on exit.
void splay(node &src, node &dst) {
    pushdown(src);
    while (src!=dst) {
        if (src->pre==dst) {
            // src is a direct child of the target node: one rotation finishes.
            if (dst->ch[0]==src) rotate(src, 1); else rotate(src, 0);
            break;
        }
        else {
            // Double rotation through parent y and grandparent z.
            node y=src->pre, z=y->pre;
            if (z->ch[0]==y) {
                if (y->ch[0]==src) {
                    // left-left (zig-zig): rotate the parent first, then src
                    rotate(y, 1);
                    rotate(src, 1);
                }else {
                    // left-right (zig-zag): rotate src twice
                    rotate(src, 0);
                    rotate(src, 1);
                }
            }
            else {
                if (y->ch[1]==src) {
                    // right-right (zig-zig)
                    rotate(y, 0);
                    rotate(src, 0);
                }else {
                    // right-left (zig-zag)
                    rotate(src, 1);
                    rotate(src, 0);
                }
            }
            // Compared against whatever dst refers to *after* the rotations.
            // NOTE(review): when dst aliases a pointer rewritten by rotate()
            // this test is false and the loop condition ends the loop instead.
            if (z==dst) break;
        }
        update(src);
    }
    update(src);
}
 
// Locates the node with in-order rank k (1-based, counted over every node in
// the tree) and splays it into the position denoted by f.
// Tags are pushed down on every node visited during the descent.
// There is no range check: k is expected to be between 1 and the tree size.
void select(ll k, node &f) {
    node cur = root;
    for (;;) {
        pushdown(cur);
        const ll leftCount = getsize(cur->ch[0]);
        if (k == leftCount + 1) {
            break;                      // cur is the k-th node of this subtree
        }
        if (k <= leftCount) {
            cur = cur->ch[0];
        } else {
            k -= leftCount + 1;         // skip the left subtree and cur itself
            cur = cur->ch[1];
        }
    }
    pushdown(cur);
    splay(cur, f);
}
 
// Brings the node of rank l to the root and the node of rank r + 2 to the
// root's right child, so that the nodes of rank l+1 .. r+1 form the left
// subtree of root->ch[1]. With the leading sentinel created by
// initsplaytree() holding rank 1, those are sequence positions l .. r.
inline void selectsegment(ll l,ll r) {
    const ll pastEnd = r + 2;
    select(l, root);
    // Bound only now: the first select() has replaced `root`.
    node &rightSlot = root->ch[1];
    select(pastEnd, rightSlot);
}
 
// Inserts `value` as a new element directly after position `pos`
// (pos == 0 inserts at the front).
// The node comes from the recycle stack S when it is non-empty, otherwise
// from the next unused slot of the pool `mem`.
void insert(ll pos, ll value) {
    // Empty segment: afterwards root is the element at `pos` and root->ch[1]
    // is the element that follows it.
    selectsegment(pos + 1, pos);
    node successor = root->ch[1];
    pushdown(root);
    pushdown(successor);

    node fresh;
    if (S.empty()) {
        fresh = &mem[memtop++];
    } else {
        fresh = S.top();
        S.pop();
    }
    fresh->init(value);

    // Splice the new node between root and its former right child.
    fresh->pre = root;
    root->ch[1] = fresh;
    fresh->ch[1] = successor;
    successor->pre = fresh;

    // Splaying the lowest changed node refreshes the aggregates above it.
    splay(successor, root);
}
 
// Adds `value` to every element at positions [a, b].
// The amount is recorded as a pending tag on the root of the isolated segment
// and that node is then splayed to the top, which applies the tag to the node
// and recomputes the aggregates on the way.
void add(ll a,ll b, ll value) {
    selectsegment(a, b);
    node seg = root->ch[1]->ch[0];      // subtree holding exactly [a, b]
    pushdown(seg);
    update(seg);
    seg->lazy += value;
    seg->min += value;
    splay(seg, root);
}
 
// Reverses the order of the elements at positions [a, b] by toggling the
// pending-reversal flag on the root of the isolated segment, then splays that
// node to the top.
void reverse(ll a, ll b) {
    selectsegment(a, b);
    node seg = root->ch[1]->ch[0];      // subtree holding exactly [a, b]
    seg->rev ^= 1;
    splay(seg, root);
}
 
// Cyclically shifts the elements at positions [a, b] toward the back by t
// places: the last t elements of the range move to its front.
//
// t is reduced modulo the range length, and negative amounts are mapped to
// the equivalent non-negative shift. An empty range or an effective shift of
// zero is a no-op. Without this normalisation a zero shift dereferences a
// null subtree and an out-of-range shift selects a rank outside the segment.
void revolve(ll a, ll b, ll t) {
    const ll len = b - a + 1;
    if (len <= 0) return;
    t %= len;
    if (t < 0) t += len;    // remainder of a negative dividend is negative
    if (t == 0) return;

    node p1, p2;
    selectsegment(a, b);

    // Bring the last element that keeps its place in front of the moved tail
    // (position b - t) to the top of the segment; its right subtree is then
    // exactly the tail of t elements. Detach that tail.
    select(b + 1 - t, root->ch[1]->ch[0]);
    p1 = root->ch[1]->ch[0];
    pushdown(p1);
    p2 = p1->ch[1];
    p1->ch[1] = NULL;

    // Bring the first element of the range (position a) to the top of the
    // segment; it has no left child, so the tail is hung there.
    select(a + 1, root->ch[1]->ch[0]);
    p1 = root->ch[1]->ch[0];
    pushdown(p1);
    p1->ch[0] = p2;
    p2->pre = p1;

    // Splaying the re-attached subtree refreshes every aggregate above it.
    splay(p2, root);
}
 
// Returns the smallest element at positions [a, b].
ll getmin(ll a, ll b) {
    selectsegment(a, b);
    node anchor = root->ch[1];
    pushdown(anchor);
    node seg = anchor->ch[0];           // subtree holding exactly [a, b]
    pushdown(seg);
    update(seg);
    return seg->min;
}
 
// Returns the sum of the elements at positions [a, b].
ll getsum(ll a, ll b) {
    selectsegment(a, b);
    node anchor = root->ch[1];
    pushdown(anchor);
    node seg = anchor->ch[0];           // subtree holding exactly [a, b]
    pushdown(seg);
    update(seg);
    return seg->sum;
}
 
// Removes the element at position `pos`. The detached node is pushed onto the
// recycle stack S so that a later insert() can reuse it.
void erase(ll pos) {
    selectsegment(pos, pos);            // isolates the single element
    node holder = root->ch[1];
    pushdown(holder);
    S.push(holder->ch[0]);              // recycle the removed node
    holder->ch[0] = NULL;
    // Splaying the node that lost its child recomputes the aggregates of it
    // and of the old root.
    splay(holder, root);
}
 
 
// Cuts the elements at positions [a, b] out of the sequence and re-inserts
// them, in the same order, directly after position c. c is interpreted in the
// sequence that remains after the cut.
//
// After each structural change the aggregates of root->ch[1] and of root are
// rebuilt with update(). Adjusting only `size` by hand leaves `min` and `sum`
// stale on both nodes. An empty selection is a no-op.
void cutandmove(ll a,ll b,ll c)
{
    selectsegment(a,b);
    node CutRoot=root->ch[1]->ch[0];
    if (CutRoot==NULL) return;

    // Detach the segment. selectsegment() has already pushed down the tags of
    // root and root->ch[1], so update() sees their final values.
    CutRoot->pre=NULL;
    root->ch[1]->ch[0]=NULL;
    update(root->ch[1]);
    update(root);

    // Open the empty gap after position c and hang the segment into it.
    selectsegment(c+1,c);
    CutRoot->pre=root->ch[1];
    root->ch[1]->ch[0]=CutRoot;
    update(root->ch[1]);
    update(root);
}
 
// Removes the elements at positions [a, b] from the sequence. The detached
// subtree is neither returned nor recycled.
//
// The aggregates of root->ch[1] and of root are rebuilt with update() after
// the detach. Patching `size` alone leaves `min` and `sum` still covering the
// removed elements. An empty selection is a no-op.
void cut(ll a,ll b)
{
    selectsegment(a,b);
    node CutRoot=root->ch[1]->ch[0];
    if (CutRoot==NULL) return;

    CutRoot->pre=NULL;
    root->ch[1]->ch[0]=NULL;
    // selectsegment() has already pushed down the tags of both nodes.
    update(root->ch[1]);
    update(root);
}
 
// Output buffer filled by inorder(); it is appended to, never cleared here.
vector<ll> ans;

// Appends the values of the subtree rooted at x to `ans` in sequence order,
// pushing pending tags down on the way. Nodes whose value equals `inf` are
// skipped, which hides the sentinels (and any real element that happens to
// equal inf).
//
// Uses an explicit stack instead of recursion: the depth of a splay tree can
// be linear in the number of nodes, and the pool allows about 1e6 of them,
// which is enough to overflow the call stack.
void inorder(node x)
{
    stack<node> pending;
    node cur = x;
    while (cur != NULL || !pending.empty()) {
        // Descend along left children; tags are resolved before a node's
        // children are read, so a pending reversal is honoured.
        while (cur != NULL) {
            pushdown(cur);
            pending.push(cur);
            cur = cur->ch[0];
        }
        cur = pending.top();
        pending.pop();
        if (cur->value != inf) ans.push_back(cur->value);
        cur = cur->ch[1];
    }
}
 
// Rebuilds the tree from scratch so that it holds a[0 .. n-1] in order.
//
// The node pool and the recycle stack are reset, then two sentinel nodes
// carrying `inf` are created as root and root's right child. They bracket the
// sequence so that every range has a predecessor and a successor node.
// The tail sentinel's parent link and root's aggregates are set here, so the
// tree is consistent even when n == 0. Leaving them to the first insert()
// would mean an empty tree has a child whose `pre` is NULL.
void initsplaytree(ll *a, ll n) {
    memtop = 0;
    while (!S.empty()) S.pop();

    root = &mem[memtop++];
    root->init(inf);
    node tail = &mem[memtop++];
    tail->init(inf);
    root->ch[1] = tail;
    tail->pre = root;
    update(root);

    // Append the elements one by one: each goes after the current last position.
    for(ll i=0;i<n;i++) insert(i, a[i]);
}
 
ll v[N];
char op[10];
int main() {
    ll n, m;
    scanf("%lld", &n);
    for(ll i=0;i<n;i++) scanf("%lld", &v[i]);
    scanf("%lld", &m);
    initsplaytree(v, n);
    while (m--) {
        getchar();
        scanf("%s", &op);
        ll l, r, pos, d;
        if(op[0]=='A'){     //区间[a,b]中的数都加上value 
             scanf("%lld%lld%lld", &l, &r, &d);
             if(l>r) swap(l,r);
             add(l, r, d);
        }else if(op[0]=='I'){    //在pos位后插入一个数
             scanf("%lld %lld", &pos,&d);
             insert(pos, d);
        }else if(op[0]=='D'){   // 删除pos位置的数
             scanf("%lld", &pos);  
             erase(pos);
        }else if(op[0]=='M'){  //求[l,r]区间最值 
             scanf("%lld%lld", &l, &r);
             if(l>r) swap(l,r);
             printf("%lld\n",getmin(l,r));
        }else if(op[3]=='E'){   //区间[a,b]中的数翻转 
              scanf("%lld%lld", &l, &r);
              if(l>r) swap(l,r);
              reverse(l,r);
        }else{  //区间[a,b]中的数向后循环移t位 
             scanf("%lld%lld%lld", &l, &r, &d);
             if(l>r) swap(l,r);
             d%=(r-l+1);
             if(d!=0) revolve(l,r,d);
        }
    }
    return 0;
}

https://blog.csdn.net/bfk_zr/article/details/78635901

学是不可能学的了,记得打印带走。


添加新评论