线段树

Author Avatar
tkandi 10月 23, 2017
  • 在其它设备中阅读本文章

前言

假如你有一个数组$a$,现要多次询问区间和$\sum_{i=l}^r{a_i}$。
如果是不带修改的,那么可以前缀和,$sum_i = \sum_{j = 1}^i{a_j}$,那么$\sum_{i = l}^r{a_i} = sum_r - sum_{l - 1}$。这是预处理$O(n)$,单次询问$O(1)$。
那么如果要支持修改呢?显然是不能用前缀和了。
这里给出一种思想,就是在询问之前对数组进行一些预处理,且是对任意询问是有用的,以来降低后面单个询问的复杂度。前面的前缀和其实就是这种思想简单的一种。



简介

这里介绍一种新的数据结构——线段树。
线段树(segment tree) 作为一种高级数据结构,其思想基础且简单。
线段树主要运用了二分思想。
以询问区间和为例,对于一个区间$[l, r]$,把它分成左右两段$[l, m]$和$[m + 1, r] \ (m = (l + r) >> 1)$,那么$\sum_{i = l}^r{a_i} = \sum_{i = l}^m{a_i} + \sum_{i = m + 1}^r{a_i}$。如果已经知道了这两个子区间的和,那么就可以求出这整个区间的和。至于如何知道这两个子区间的和,我们发现这个问题与原来的问题相当,只需要继续递归做就可以了。直到区间内只有一个元素,那么这个区间的和就是这个元素的值。仔细观察,我们如果把每个区间看做一个点,那么每个点最多有两个子节点,这是一棵二叉树,而且树的深度为$O(log_{2}n)$。且这不一棵是满二叉树,假设它是满二叉树,根据二叉树的性质,深度为$O(log_{2}n)$的满二叉树的节点数是$O(n)$的,所以线段树的节点数也是$O(n)$的。



基本操作

存储

线段树的存储有多种实现方法。

完全二叉树

对于每个节点,要记录它的所代表的区间的左端点和右端点,以及这个区间的信息。

struct Segment_Tree_Node {
    Segment_Tree_Node(int l = 0, int r = 0) : l(l), r(r) {
        x = Info();
        return;
    }
    int l, r;
    Info x;
};

当前点是$k$,那么它的左子节点(如果有)就是$k << 1$,它的右子节点(如果有)就是$k << 1 | 1$,它的父亲节点(如果有)就是$k >> 1$。可以证明用到节点的最大编号是小于或等于$4 \times n$的。具体证明,待填坑。。。

指针
struct Segment_Tree_Node {
    Segment_Tree_Node(int l = 0, int r = 0) : l(l), r(r) {
        lc = rc = NULL;
        x = Info();
        return;
    }
    int l, r;
    Segment_Tree_Node *lc, *rc;
    Info x;
};

指针有两种实现方法:

  • 内存动态分配,动态释放。
  • 先开一个内存池,往后用到不断分配,也可以用数组模拟。

线段树一共有$2 \times n$个节点,每个节点除了用完全二叉树存储的要记录的以外,还要记录它左右子节点的指针。

动态开点

对于一个修改或询问,与它有关的节点数为$O(log_{2}n)$,所以不用把所有的节点都开出来,一个节点只有被用到时才会开。空间复杂度一般为$O(mlog_{2}n)$,所以当$n$到达$10^9$或更大时也能做。

后文中未明确说明,一般用完全二叉树结构存储。



单点修改 && 区间查询

每次有两种操作:1. 把一个元素加上一个数;2. 询问区间和。

模板题目链接

预处理

记录下当前节点的所代表的区间的左端点和右端点,以及这个区间的信息(像这个例子,询问区间和就记录这个节点所代表的区间的和)。
若某个节点所代表的区间只有一个元素,那么这个区间的信息就由该元素算出。否则,它就有左右子节点,递归算出它们的信息,该节点的信息就由它们算出。
到这里预处理就结束了。

#define lc k << 1
#define rc k << 1 | 1

const int N = 100010;

int n;
int a[N];

struct Node {
    int l, r, sum;
} p[N << 2];

void Build(int k, int l, int r) {
    p[k].l = l;
    p[k].r = r;
    if (l == r) {
        p[k].sum = a[l];
        return;
    }
    int m = (l + r) >> 1;
    Build(lc, l, m);
    Build(rc, m + 1, r);
    p[k].sum = p[lc].sum + p[rc].sum;
    return;
}
修改

在当前节点所代表的区间修改在$x$位置上的元素加上$d$。
若当前节点是叶子节点,那么该节点肯定是$x$,把该节点的信息对$d$修改。
否则,若它在当前节点左子节点所代表的区间内,递归修改左子节点,否则递归修改右子节点,修改完后更新当前节点的信息。
由于线段树的深度为$O(log_{2}n)$,所以从根节点到某叶子节点的距离为$O(log_{2}n)$,所以时间复杂度为$O(log_{2}n)$。

void Modify(int k, int x, int d) {
    if (p[k].l == x && x == p[k].r) {
        p[k].sum += d;
        return;
    }
    if (x <= p[lc].r) Modify(lc, x, d);
    else Modify(rc, x, d);
    p[k].sum = p[lc].sum + p[rc].sum;
    return;
}
询问

在当前节点所代表的区间内询问在区间$[l, r]$中的元素的和。
若当前节点所代表的区间是区间$[l, r]$的一个子区间,那么直接返回当前节点所记录的和。
否则若区间$[l, r]$与当前节点左子节点所代表的区间有交集,那么答案加上递归询问左子节点的和,若区间$[l, r]$与当前节点右子节点所代表的区间有交集,那么答案加上递归询问右子节点的和。最后的总和就是答案。
一个区间最多被划分成线段树上的$O(log_{2}n)$个节点,所以单次询问的时间复杂度为$O(log_{2}n)$。具体证明,待填坑。。。

int Query(int k, int l, int r) {
    if (l <= p[k].l && p[k].r <= r) return p[k].sum;
    int sum = 0;
    if (l <= p[lc].r) sum += Query(lc, l, r);
    if (p[rc].l <= r) sum += Query(rc, l, r);
    return sum;
}

以下是主程序:

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i)
        scanf("%d", &a[i]);
    Build(1, 1, n);
    while (m--) {
        int t;
        scanf("%d", &t);
        if (t == 1) {
            int x, d;
            scanf("%d%d", &x, &d);
            Modify(1, x, d);
        } else {
            int l, r;
            scanf("%d%d", &l, &r);
            printf("%d\n", Query(1, l, r));
        }
    }
    return 0;
}



区间修改 && 单点查询

每次有两种操作:1. 把一个区间加上一个数;2. 询问一个元素的和。

模板题目链接

这种被介绍在树状数组中或许更好一点。
把原序列$a$差分成序列$b$,对于每一个修改操作$(l, r, x)$,把它转成两个操作,在差分数组的第$l$个元素上加上$x$,在差分数组的第$r + 1$个元素上加上$-x$。这样,如果要求原序列第$x$个元素的值,那么就是差分数组的前$x$个元素的和$\sum_{i = 1}^x{b_i}$。这样就把(区间修改 && 单点查询)转化成(单点修改 && 区间查询)了。

以下是主程序:

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &b[i]);
        a[i] = b[i] - b[i - 1];
    }
    Build(1, 1, n);
    while (m--) {
        int t;
        scanf("%d", &t);
        if (t == 1) {
            int l, r, d;
            scanf("%d%d%d", &l, &r, &d);
            Modify(1, l, d);
            if (r < n) Modify(1, r + 1, -d);
        } else {
            int x;
            scanf("%d", &x);
            printf("%d\n", Query(1, 1, x));
        }
    }
    return 0;
}



区间修改 && 区间查询

每次有两种操作:1. 把一个区间加上一个数;2. 询问区间和。

模板题目链接

这里介绍一种 Lazy-Tag 思想。
在每个节点记录一个标记$lazy$,代表这个区间还没有下传的标记。例如,在区间加一个数中,记录的就是这个区间的所有子区间还要加的数的总和。

预处理

增加对每个节点$lazy$标记的预处理。

#define lc k << 1
#define rc k << 1 | 1

typedef long long LL;

const int N = 100010;

int n;
LL a[N];

struct Node {
    int l, r;
    LL sum, lazy;
} p[N << 2];

void Build(int k, int l, int r) {
    p[k].l = l;
    p[k].r = r;
    if (l == r) {
        p[k].sum = a[l];
        p[k].lazy = 0;
        return;
    }
    int m = (l + r) >> 1;
    Build(lc, l, m);
    Build(rc, m + 1, r);
    p[k].sum = p[lc].sum + p[rc].sum;
    return;
}
标记下传

把节点$k$的标记下传。
左子节点的和加上左子节点所代表的区间所含的元素个数$\times$当前节点的标记。左子节点的标记加上当前节点的标记。
右子节点的和加上右子节点所代表的区间所含的元素个数$\times$当前节点的标记。右子节点的标记加上当前节点的标记。
不要忘了把当前节点的标记清零。

inline void Pushdown(int k) {
    p[lc].sum += p[k].lazy * (p[lc].r - p[lc].l + 1);
    p[lc].lazy += p[k].lazy;
    p[rc].sum += p[k].lazy * (p[rc].r - p[rc].l + 1);
    p[rc].lazy += p[k].lazy;
    p[k].lazy = 0;
    return;
}
修改

在当前节点所代表的区间内将在区间$[l, r]$中的元素的加上$d$。
若当前节点所代表的区间是区间$[l, r]$的一个子区间,那么当前节点的和加上$d \times$当前区间的元素个数,$lazy$标记加上$d$。
否则下传当前节点的标记。若区间$[l, r]$与当前节点左子节点所代表的区间有交集,那么对左子节点进行修改,若区间$[l, r]$与当前节点右子节点所代表的区间有交集,那么对右子节点进行修改,修改完后更新当前节点的信息。

void Modify(int k, int l, int r, LL d) {
    if (l <= p[k].l && p[k].r <= r) {
        p[k].sum += d * (p[k].r - p[k].l + 1);
        p[k].lazy += d;
        return;
    }
    if (p[k].lazy) Pushdown(k);
    if (l <= p[lc].r) Modify(lc, l, r, d);
    if (p[rc].l <= r) Modify(rc, l, r, d);
    p[k].sum = p[lc].sum + p[rc].sum;
    return;
}
询问

在当前节点所代表的区间内询问在区间$[l, r]$中的元素的和。
若当前节点所代表的区间是区间$[l, r]$的一个子区间,那么直接返回当前节点所记录的和。
否则下传当前节点的标记。若区间$[l, r]$与当前节点左子节点所代表的区间有交集,那么答案加上递归询问左子节点的和,若区间$[l, r]$与当前节点右子节点所代表的区间有交集,那么答案加上递归询问右子节点的和。最后的总和就是答案。

LL Query(int k, int l, int r) {
    if (l <= p[k].l && p[k].r <= r) return p[k].sum;
    if (p[k].lazy) Pushdown(k);
    LL sum = 0;
    if (l <= p[lc].r) sum += Query(lc, l, r);
    if (p[rc].l <= r) sum += Query(rc, l, r);
    return sum;
}

以下是主程序:

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i)
        scanf("%lld", &a[i]);
    Build(1, 1, n);
    while (m--) {
        int t, l, r;
        scanf("%d%d%d", &t, &l, &r);
        if (t == 1) {
            LL d;
            scanf("%lld", &d);
            Modify(1, l, r, d);
        } else printf("%lld\n", Query(1, l, r));
    }
    return 0;
}

$Pushdown$是$O(1)$的,所以$lazy-tag$并不会影响复杂度。


</br>


扩展

可持久化线段树和主席树(函数式线段树)

可持久化线段树 就是线段树的可持久化。它有一个最初的版本,每次可在在某个版本的基础上修改并新增一个版本,且需要支持查询历史版本。

如果每次复制原来的线段树,再在新的线段树上修改,时间和空间复杂度都会到达$O(n \times m)$。这样高的复杂度是无法承受的。
我们观察后发现,线段树每次修改,只会修改线段树上$O(log_{2}n)$个节点,我们只需要新建这$O(log_{2}n)$个节点,其它的节点我们可以共用。这里就要用到动态开点了。这样时间复杂度和空间复杂度都降为$O((n + m)log_{2}n)$。

模板题目链接

预处理

现在是动态开点的,所以要开一个节点的内存池。节点信息中要加上左右子节点的指针(可以用数组模拟)。

const int N = 100010, M = 100010, P = 17, _INF = 0x80000000;

int n, m;
int a[N];

int tot = 0, root[M];
struct Node {
    int lc, rc, l, r, Max;
} p[(N << 2) + M * P];

int Build(int l, int r) {
    int k = ++tot;
    p[k].l = l;
    p[k].r = r;
    if (l == r) {
        p[k].Max = a[l];
        return k;
    }
    int m = (l + r) >> 1;
    p[k].lc = Build(l, m);
    p[k].rc = Build(m + 1, r);
    p[k].Max = max(p[p[k].lc].Max, p[p[k].rc].Max);
    return k;
}
修改

对于所有要修改的节点,再新建一个(从内存池中分配一个),对其进行修改。

int Modify(int k, int x, int d) {
    int K = ++tot;
    p[K] = p[k];
    if (p[k].l == x && x == p[k].r) {
        p[K].Max = d;
        return K;
    }
    if (x <= p[p[k].lc].r) p[K].lc = Modify(p[k].lc, x, d);
    else p[K].rc = Modify(p[k].rc, x, d);
    p[K].Max = max(p[p[K].lc].Max, p[p[K].rc].Max);
    return K;
}
询问

询问其实变化不大。

int Query(int k, int l, int r) {
    if (l <= p[k].l && p[k].r <= r) return p[k].Max;
    int Max = _INF;
    if (l <= p[p[k].lc].r) Max = max(Max, Query(p[k].lc, l, r));
    if (p[p[k].rc].l <= r) Max = max(Max, Query(p[k].rc, l, r));
    return Max;
}

以下是主程序:

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i)
        scanf("%d", &a[i]);
    int cnt = 0;
    root[cnt = 1] = Build(1, n);
    while (m--) {
        int t;
        scanf("%d", &t);
        if (!t) {
            int k, l, r;
            scanf("%d%d%d", &k, &l, &r);
            printf("%d\n", Query(root[k], l, r));
        } else {
            int k, x, d;
            scanf("%d%d%d", &k, &x, &d);
            root[++cnt] = Modify(root[k], x, d);
        }
    }
    return 0;
}



线段树合并

将两棵线段树合并成一棵新的线段树。

int Merge(int k1, int k2) {
    if (!k1) return k2;
    if (!k2) return k1;
    int k = ++tot;
    p[k].lc = Merge(p[k1].lc, p[k2].lc);
    p[k].rc = Merge(p[k1].rc, p[k2].rc);
    p[k].sum = p[p[k].lc].sum + p[p[k].rc].sum;
    return k;
}

开始时你有$n$棵线段树,每棵线段树都有一个节点,现在你不断把它们合并,直到只有一棵线段树,这个总的时间复杂度是$O(nlog_{2}n)$的。因为你每次合并都会使某些区间上所含的点变多,那么总复杂度即为线段树上所有节点所包含的元素数之和,线段树一共有$O(log_{2}n)$层,每层都有$O(n)$个节点,所以一共有$O(nlog_{2}n)$个节点,所以时间复杂度也是这个。

线段树分裂