kkogoro

树套树模板
树套树模板用于解决偏序问题Luogu P3380题面题目描述您需要写一种数据结构(可参考题目标题),来维护一个有序...
扫描右侧二维码阅读全文
16
2018/11

树套树模板

树套树模板

用于解决偏序问题

Luogu P3380

题面

题目描述

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  1. 查询k在区间内的排名
  2. 查询区间内排名为k的值
  3. 修改某一位值上的数值
  4. 查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
  5. 查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)

输入格式

第一行两个数 n,m 表示长度为n的有序序列和m个操作

第二行有n个数,表示有序序列

下面有m行,opt表示操作标号

若opt=1 则为操作1,之后有三个数l,r,k 表示查询k在区间[l,r]的排名

若opt=2 则为操作2,之后有三个数l,r,k 表示查询区间[l,r]内排名为k的数

若opt=3 则为操作3,之后有两个数pos,k 表示将pos位置的数修改为k

若opt=4 则为操作4,之后有三个数l,r,k 表示查询区间[l,r]内k的前驱

若opt=5 则为操作5,之后有三个数l,r,k 表示查询区间[l,r]内k的后继

输出格式

对于操作1,2,4,5各输出一行,表示查询结果

输入样例

9 6
4 2 2 1 9 4 0 1 1
2 1 4 3
3 4 10
2 1 4 3
1 2 5 9
4 3 9 5
5 2 8 5

输出样例

2
4
3
4
9

树状数组套权值线段树法:

您应该在掌握树状数组,权值线段树和主席树之后再开始阅读下文

简要题解

如果没有修改操作,这道题可以使用主席树做,但是现在有了修改操作。

由于维护的信息满足区间减法,所以可以用前缀和相减来获得区间信息(类比主席树)。

所以如果不修改,维护的就是不带修改的前缀和;而现在有了修改,维护的就是带修改前缀和,在一维序列上这个可以使用树状数组来实现。

由于查的信息的特殊性,我们需要用权值线段树来维护具体信息(值域内有多少个值)。

综上所述,我们求的是带修改的权值线段树前缀。(说的肯定不标准,感性理解就好~)

这里的实现方法是在权值线段树外层套一个树状数组。

详细一点来说,树状数组的每个元素不是一个int型,而是一个权值线段树,一个权值线段树的modify操作就相当于改了树状数组的一个元素的值。

如果我们修改一个位置的值,在树状数组的元素是一个int型的时候是这样写的

void modify(int pos, int n, int delta) {
    for (int i = pos; i <= n; i += lowbit(i)) {
        data[i] += delta; //修改元素值的部分
    }
    return;
}

而现在树状数组的一个元素是权值线段树的时候,应该这样写

void modify(int pos, int n, int val, int delta) {
    for (int i = pos; i <= n; i += lowbit(i)) {
        SegmentTree::modify(root[i], MIN_L, MAX_R, val, delta);
    }
    return;
}

类比修改,我们可以推断出查询操作的大致框架

int query(int pos, int aim) {
    int ret = 0;
    for (int i = pos; i > 0; i -= lowbit(i)) {
        ret += SegmentTree::query(root[i], MIN_L, MAX_R, aim);
    }
    return;
}
操作1

查 $k$ 在 $[1,R]$ 的排名 减去 $k$ 在 $[1,L-1]$ 的排名即可。

操作2

$[1,R]$ 内对应的节点和$[1,L-1]$对应的节点同时进行主席树查k大的相应操作,只是同时维护多个节点,建议看代码理解。

操作3

前面已经给出了,只需要删除之前的值,加入现在的值就可以了。

操作4

查询 $k-1$ 的排名,如果为 $0$ 证明 $k$ 没前驱,否则输出排名为 $rank_{k-1}+1$ 的元素

操作5

查询 $k$ 的排名,如果为区间长度 $R-L+1$ ,则 $k$ 没有后继,否则输出排名为 $rank_{k} + 1$ 的元素

要注意我这里的排名是重复元素是计数的,即对于 $\{1, 2, 2, 3\}$ ,$2$ 的排名是 $3$ ,因此操作1的答案实际上是 $rank_{k - 1} + 1$ ,这里需要注意

代码

#include <iostream>
#include <algorithm>
#include <cstdio>


inline int read() {
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}


const int MAXN = 50007;


int que_1[MAXN], que_2[MAXN], tmp[MAXN << 1];


namespace SegmentTree {
    #define C(x) node[x].cnt
    #define L(x) node[x].lc
    #define R(x) node[x].rc
    struct Node {
        int lc, rc, cnt;    
    } node[MAXN << 7];
    int root[MAXN], cntNode;
    inline int newNode() {
        Node *it = &node[++cntNode];    
        it->lc = it->rc = it->cnt = 0;
        return cntNode;
    }
    void modify(int &cur, int l, int r, int aim, int delta) {
        if (cur == 0) cur = newNode();
        C(cur) += delta;
        if (l == r) return;
        int mid = (l + r) >> 1;
        if (aim <= mid) {
            modify(L(cur), l, mid, aim, delta);
        }
        else {
            modify(R(cur), mid + 1, r, aim, delta);    
        }
        return;
    }
    int queryRank(int cur, int l, int r, int aim) {
        if (cur == 0) return 0;
        if (l == r) return C(cur);
        int mid = (l + r) >> 1;
        if (aim <= mid) return queryRank(L(cur), l, mid, aim);
        else return C(L(cur)) + queryRank(R(cur), mid + 1, r, aim);
    }
    int queryKth(int tot_1, int tot_2, int l, int r, int k) {
        if (l == r) return l;
        int cnt = 0;
        for (int i = 1; i <= tot_1; ++i) {
            cnt += C(L(que_1[i]));    
        }
        for (int i = 1; i <= tot_2; ++i) {
            cnt -= C(L(que_2[i]));
        }
        int mid = (l + r) >> 1;
        if (cnt >= k) {
            for (int i = 1; i <= tot_1; ++i) {
                que_1[i] = L(que_1[i]);
            }
            for (int i = 1; i <= tot_2; ++i) {
                que_2[i] = L(que_2[i]);
            }
            return queryKth(tot_1, tot_2, l, mid, k);
        }
        else {
            for (int i = 1; i <= tot_1; ++i) {
                que_1[i] = R(que_1[i]);
            }
            for (int i = 1; i <= tot_2; ++i) {
                que_2[i] = R(que_2[i]);
            }    
            return queryKth(tot_1, tot_2, mid + 1, r, k - cnt);
        }
    }
}


namespace BIT {
    inline int lowbit(int x) {
        return x & (-x);
    }
    inline void modify(int pos, int n, int tot, int val, int delta) {
        for (int i = pos; i <= n; i += lowbit(i)) {
            SegmentTree::modify(SegmentTree::root[i], 1, tot, val, delta);
        }
        return;
    }
    inline int getRank(int l, int r, int val, int tot) {
        --l;
        int ret = 0;
        for (int i = r; i > 0; i -= lowbit(i)) {
            ret += SegmentTree::queryRank(SegmentTree::root[i], 1, tot, val);
        }
        for (int i = l; i > 0; i -= lowbit(i)) {
            ret -= SegmentTree::queryRank(SegmentTree::root[i], 1, tot, val);
        }
        return ret;
    }
    inline int getKth(int l, int r, int k, int tot) {
        --l;
        int tot_1 = 0, tot_2 = 0;
        for (int i = r; i > 0; i -= lowbit(i)) {
            que_1[++tot_1] = SegmentTree::root[i];
        }
        for (int i = l; i > 0; i -= lowbit(i)) {
            que_2[++tot_2] = SegmentTree::root[i];
        }
        return tmp[SegmentTree::queryKth(tot_1, tot_2, 1, tot, k)];
    }
}


int data[MAXN];


struct Option {
    int type, l, r, k;
} option[MAXN];


int main() {
    int n = read();
    int m = read();
    int tot = 0;
    for (int i = 1; i <= n; ++i) {
        data[i] = read();
        tmp[++tot] = data[i];
    }
    for (int i = 1; i <= m; ++i) {
        option[i].type = read();
        if (option[i].type == 3) {
            option[i].l = read();
            option[i].k = read();    
        }
        else {
            option[i].l = read();
            option[i].r = read();
            option[i].k = read();
        }
        
        if (option[i].type != 2) {
            tmp[++tot] = option[i].k;
        }
    }    
    
    std::sort(tmp + 1, tmp + 1 + tot);
    tot = std::unique(tmp + 1, tmp + 1 + tot) - tmp - 1;
    
    for (int i = 1; i <= n; ++i) {
        data[i] = std::lower_bound(tmp + 1, tmp + 1 + tot, data[i]) - tmp;
        BIT::modify(i, n, tot, data[i], 1);
    }
    
    for (int i = 1; i <= m; ++i) {
        switch (option[i].type) {
            case 1 : {
                option[i].k = std::lower_bound(tmp + 1, tmp + 1 + tot, option[i].k) - tmp;
                printf("%d\n", BIT::getRank(option[i].l, option[i].r, option[i].k - 1, tot) + 1);
                break;
            }
            case 2 : {
                printf("%d\n", BIT::getKth(option[i].l, option[i].r, option[i].k, tot));
                break;
            }
            case 3 : {
                BIT::modify(option[i].l, n, tot, data[option[i].l], -1);    
                data[option[i].l] = std::lower_bound(tmp + 1, tmp + 1 + tot, option[i].k) - tmp;
                BIT::modify(option[i].l, n, tot, data[option[i].l], 1);    
                break;
            }
            case 4 : {
                option[i].k = std::lower_bound(tmp + 1, tmp + 1 + tot, option[i].k) - tmp;
                int rank = BIT::getRank(option[i].l, option[i].r, option[i].k - 1, tot);
                if (rank == 0) printf("-2147483647\n");
                else printf("%d\n", BIT::getKth(option[i].l, option[i].r, rank, tot));
                break;
            }
            case 5 : {
                option[i].k = std::lower_bound(tmp + 1, tmp + 1 + tot, option[i].k) - tmp;
                int rank = BIT::getRank(option[i].l, option[i].r, option[i].k, tot);
                if (rank == option[i].r - option[i].l + 1) printf("2147483647\n");
                else printf("%d\n", BIT::getKth(option[i].l, option[i].r, rank + 1, tot));
            }
        }
    }
    return 0;
}

BZOJ3065带插入区间K小值

替罪羊套权值线段树

照着hzwer的写的

#include <iostream>
#include <cstdio>
#include <vector>


const int MAXN = 35007;
const int MAXM = 35007;
const int MAXV = 70000;


inline int read_int() {
    int x = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + ch - '0';
        ch = getchar();    
    }
    return x * f;
}
inline char read_char() {
    char ch = getchar();
    while (ch != 'Q' && ch != 'M' && ch != 'I') ch = getchar();
    return ch;
}


namespace SegmentTree {
#define L(x) node[x].ch[0]
#define R(x) node[x].ch[1]
#define S(x) node[x].sum
    static const int maxn = (MAXN + MAXM) * 16 * 16;
    struct Node {
        int ch[2], sum;
    } node[maxn];
    int rec_stack[maxn], top, cntNode;
    inline int newNode() {
        int ret = top ? rec_stack[top--] : ++cntNode;
        L(ret) = R(ret) = S(ret) = 0;
        return ret;
    }
    void delNode(int &x) {
        if (!x) return;
        delNode(L(x));
        delNode(R(x));
        rec_stack[++top] = x;
        x = 0;
        return;
    }
    void add(int &cur, const int &l, const int &r, const int &aim, const int &delta) {
        if (!cur) cur = newNode();
        S(cur) += delta;
        if (!S(cur)) return void(delNode(cur));
        if (l == r) return;
        const int &mid = (l + r) >> 1;
        if (aim <= mid) add(L(cur), l, mid, aim, delta);
        else add(R(cur), mid + 1, r, aim, delta);
        return;
    }
#undef L
#undef R
#undef S
}
int real_val[MAXN + MAXM];
const double alpha = 0.75;
struct ScapeGoatTree {
#define L(x) node[x].ch[0]
#define R(x) node[x].ch[1]
#define V(x) node[x].val
#define RT(x) node[x].rt
#define S(x) SegmentTree::node[ node[x].rt ].sum
    static const int maxn = MAXN + MAXM;
    struct Node {
        int ch[2], rt;
    } node[maxn];
    std::vector<int> tmp1, tmp2;
    int cntNode, que[maxn], tail, root, rebuild_node, rebuild_fa;
    inline void init() {
        rebuild_node = rebuild_fa = tail = root = cntNode = 0;
        return;
    }
    inline int newNode() {
        ++cntNode;
        L(cntNode) = R(cntNode) = RT(cntNode) = 0;
        return cntNode;
    }
    void DFS(const int &x) {
        if (!x) return;
        DFS(L(x));
        SegmentTree::delNode(RT(x));
        que[++tail] = x;
        DFS(R(x));
        return;
    }
    int build(const int &l, const int &r) {
        if (l > r) return 0 ;
        const int &mid = (l + r) >> 1;
        const int &cur = que[mid];
        for (int i = l; i <= r; ++i) 
            SegmentTree::add(RT(cur), 0, MAXV, real_val[ que[i] ], 1);
        L(cur) = build(l, mid - 1);
        R(cur) = build(mid + 1, r);
        return cur;
    }
    inline int Rebuild(const int &x) {
        tail = 0;
        DFS(x);
        return build(1, tail);
    }
    inline void maintain() {
        if (!rebuild_node) return;

        if (rebuild_fa) 
            node[rebuild_fa].ch[rebuild_node == R(rebuild_fa)] = Rebuild(rebuild_node);
        else 
            root = Rebuild(rebuild_node);

        rebuild_node = 0;
        return;
    }
    inline bool NeedRebuild(const int &x) {
        return S(x) * alpha < (double)std::max(S(L(x)), S(R(x)));
    }
    void insert(int &cur, const int &rank, const int &val, const int &fa) {
        if (!cur) {
            cur = newNode();
            SegmentTree::add(RT(cur), 0, MAXV, val, 1);
            real_val[cur] = val;
            return;
        }
        SegmentTree::add(RT(cur), 0, MAXV, val, 1);
        if (S(L(cur)) >= rank) insert(L(cur), rank, val, cur);
        else insert(R(cur), rank - S(L(cur)) - 1, val, cur);

        if (NeedRebuild(cur)) rebuild_node = cur, rebuild_fa = fa;
        return;
    }
    int modify(const int &cur, const int &rank, const int &val) {
        SegmentTree::add(RT(cur), 0, MAXV, val, 1);
        int del_val;
        if (S(L(cur)) >= rank) del_val = modify(L(cur), rank, val);
        else if (S(L(cur)) + 1 == rank) del_val = real_val[cur], real_val[cur] = val;
        else del_val = modify(R(cur), rank - S(L(cur)) - 1, val);
        SegmentTree::add(RT(cur), 0, MAXV, del_val, -1);
        return del_val;
    }
    void split(const int &cur, const int &l, const int &r) {
        const int &l_size = S(L(cur)), &cur_size = S(cur);    
        if (l == 1 && r == cur_size) return void(tmp1.push_back(RT(cur)));

        if (r <= l_size) split(L(cur), l, r);
        else if (l > l_size + 1) split(R(cur), l - l_size - 1, r - l_size - 1);
        else {
            tmp2.push_back(real_val[cur]);
            if (l <= l_size) split(L(cur), l, l_size);
            if (cur_size > l_size + 1) split(R(cur), 1, r - l_size - 1);
        }
        return;
    }
    int query(const int &l, const int &r, int k) {
        tmp1.clear(), tmp2.clear();
        split(root, l, r);
        int binary_l = 0, binary_r = MAXV;
        while (binary_l < binary_r) {
            const int &mid = (binary_l + binary_r) >> 1;
            int sum = 0;
            for (int i = 0; i < (int)tmp1.size(); ++i) 
                sum += SegmentTree::node[SegmentTree::node[tmp1[i]].ch[0]].sum;
            for (int i = 0; i < (int)tmp2.size(); ++i)
                if (tmp2[i] >= binary_l && tmp2[i] <= mid) ++sum;
            if (sum >= k) {
                for (int i = 0; i < (int)tmp1.size(); ++i)
                    tmp1[i] = SegmentTree::node[tmp1[i]].ch[0];
                binary_r = mid;
            }
            else {
                for (int i = 0; i < (int)tmp1.size(); ++i)
                    tmp1[i] = SegmentTree::node[tmp1[i]].ch[1];
                binary_l = mid + 1;
                k -= sum;
            }
        }
        return binary_l;
    }
#undef L
#undef R
#undef V
#undef RT
#undef S
} SGT;


int main() {
    int n = read_int();
    SGT.init();
    for (int i = 1; i <= n; ++i) {
        SGT.que[++SGT.tail] = SGT.newNode();
        real_val[ SGT.que[SGT.tail] ] = read_int();
    }
    SGT.root = SGT.build(1, SGT.tail);
    int q = read_int(), last_ans = 0;
    while (q--) {
        char opt = read_char();
        int x, y, k;
        x = read_int() ^ last_ans;
        y = read_int() ^ last_ans;
        switch (opt) {
            case 'Q' :
                k = read_int() ^ last_ans;
                printf("%d\n", last_ans = SGT.query(x, y, k));
                break;
            case 'M' :
                SGT.modify(SGT.root, x, y);
                break;
            case 'I' :
                SGT.insert(SGT.root, x - 1, y, 0);
                SGT.maintain();
                break;
        }
    }
    return 0;
}
Last modification:May 12th, 2019 at 12:01 am

One comment

  1. 码神

    实用,good

Leave a Comment