线段树
31 December 2016

介绍

线段树,也叫区间树,英文名叫 segment tree,该数据结构主要是用来实现区间查询。

假设,我们要计算 arr[0] + arr[1] + arr[2] + ... + arr[9] ,总共需要计算 10 个数,要算 arr[0] + arr[1] + arr[2] + ... + arr[999] 也是类似,这个算法的规模是 O(n)。如果数组是无规律的,那么这个 O(n) 这个时间是无法优化的。但是,如果我们要查询多次的话,那么我们可以将结果缓存下来,那么下次查询就不用重新计算

sum[1, 10000]  = xxx;   第 1 个数到 第 10000 个数的和
sum[43, 999]   = xxx;   第 43 个数到 第 999 个数的和
sum[888, 1500] = xxx;   第 888 个数到 第 1500 个数的和
...

另外,涉及到缓存的地方就必定要考虑缓存的更新和删除

实现

在实现中,我们不能随便生成一个区间,区间的生成是由一开始的始末依次二分生成的。一图胜千言,下面是由区间 [1-10] 生成的线段树

一开始, 1, 10 分成 1, 56, 10 两半,之后以此类推。

现在我们来看看区间是怎么计算的:

要计算 1-10 区间的和,就先要计算 1-5,再加上 6-10;
要计算 1-5  区间的和,就先要计算 1-3,再加上 4-5;
要计算 1-3  区间的和,就先要计算 1-2,再加上 3-3;
...

看得出这个是一个递归。我们从根节点出发,计算到区间的左边界和右边界相等的时候,就表示我们走到了底部 ([1,1], [2,2], [3,3]...)

代码:

int data[11] = {-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};  // 除去 data[0]
int tree[100];

int left(int x) {         // 左子树对应的标号
    return 2*x;
}

int right(int x) {        // 右子树对应的标号
    return 2*x + 1;
}


// 构建线段树
void build(int root = 1, int l = 1, int r = 10) {
    if (l == r) {         // 如果当前的区间只有一个节点
        tree[root] = data[l];
        return;
    }

    if (l > r) {          // 越界,左边不能大于右边
        return;
    }

    int m = (l + r) / 2;

    build(left(root),  l,   m);
    build(right(root), m+1, r);

    tree[root] = tree[left(root)] + tree[right(root)]; // 往上迭代
}

查询

在构建完线段树之后,我们要查询区间 [query_min, query_max] 的和

区间的开始是从 [1, 10] 往中间缩。

  1. 如果我们要查询的是 1-10 的和,那么应该直接返回 tree[1]
  2. 如果我们要查询的是 1-8 的和,那么查询的过程是:
发现查询区间 [1-8] 不能完全包含 [l-r] (即[1-10])
--> 分成左右区间 [1-5] [6-10]
--> 发现查询区间 [1-8] 完全包含 [1-5]
--> 发现查询区间 [1-8] 部分包含 [6-10],将 6-10 分成 [6-8], [9-10]
----> 发现查询区间 [1-8] 部分包含 [6-8],并且 [9-10] 不在查询区间内

返回 sum([1-5]) + sum([6-8]) + sum([9, 10])
    15           21           0            ==> 36

  1. 如果我们要查询的是 3-7 的和,那么查询的过程是:
发现查询区间 [3-7] 不能完全包含 [l-r] ([1-10])
--> 分成左右区间 [1-5] [6-10]

    处理 [1-5]
----> 查询区间 [3-7] 不能完全包含 [1-5],分成 [1-3] 和 [4-5]

      处理 [1-3]
------> 查询区间 [3-7] 不能完全包含 [1-3],分成 [1-2], [3-3], --> [3-3] 完全包含

      处理 [4-5]
------> 查询区间 [3-7] 完全包含 [4-5]

    处理 [6-10]
----> 查询区间 [3-7] 不能完全包含 [6-10],分成 [6-8] 和 [9-10]

      处理 [6-8]
------> 查询区间 [3-7] 不能完全包含 [6-8],分成 [6-7], [8-8], --> [6-7] 完全包含

      处理 [9-10]
------> [9-10] 超出范围,返回 0

结果是 sum([3-3]) + sum([4-5]) + sum(6-7)

int getSum(int query_min, int query_max, int root = 1, int l = 1, int r = 10) {
    if (query_min <= l && r <= query_max) {
        return tree[root];
    }

    if (query_min > r || query_max < l) { // l, r 不在查询区间内
        return 0;
    }

    int m = (l + r) / 2;

    int lsum = getSum(query_min, query_max, left(root),  l,   m);
    int rsum = getSum(query_min, query_max, right(root), m+1, r);

    return lsum + rsum;
}

更新

更新相当于重建

void update(int index, int v) {
    data[index] = v;
    build();
}

懒惰更新

实际上,我们可以将对区间的更新推迟到我们需要获取这个区间的值的时候而不是马上重建整棵树,这个叫懒惰更新,lazy update

updateInterval(updateL, updateR, root, l, r)

lazy update 遵循以下原则:

  1. 如果 当前的节点范围 (l, r) 不在 更新范围 (updateL, updateR) 内,不做处理
  2. 如果 当前的节点有待更新的数据 ,将待更新的数据更新到当前节点。将 当前节点的更新信息 “推迟” 给左右两个子节点
  3. 如果 当前的节点范围 完全在 更新范围 内,更新当前的节点。将 待更新的信息 “推迟” 给左右两个子节点
  4. 如果 当前的节点范围 有一部分在 更新范围 内,那么遍历左右两个子节点
void updateInterval(int updateL, int updateR, int diff, int root = 1, int l = 1, int r = 9) {
    // 当前的节点有待更新的数据
    if (lazy[root] != 0) {
        tree[root] += (r-l+1) * lazy[root];

        if (l != r) {      // 存在子节点
            lazy[left(root)]  += lazy[root];
            lazy[right(root)] += lazy[root];
        }

        lazy[root] = 0;
    }

    // 不在更新范围内
    if (l > r || l > updateR || r < updateL) {
        return ;
    }

    // 当前的节点范围 完全在 更新范围 内
    if (l >= updateL && r <= updateR) {
        tree[root] += (r-l+1) * diff;

        if (l != r) {
            lazy[left(root)]  += diff;
            lazy[right(root)] += diff;
        }
        return;
    }

    // 当前的节点范围 有一部分在 更新范围
    int m = (l + r) / 2;
    updateInterval(updateL, updateR, diff, left(root),  l,   m);
    updateInterval(updateL, updateR, diff, right(root), m+1, r);

    tree[root] = tree[left(root)] + tree[right(root)];
}
int getSum(int query_min, int query_max, int root = 1, int l = 1, int r = 10) {
    // 当前的节点有待更新的数据
    if (lazy[root] != 0) {
        tree[root] += (r-l+1) * lazy[root];

        if (l != r) {      // 存在子节点
            lazy[left(root)]  += lazy[root];
            lazy[right(root)] += lazy[root];
        }

        lazy[root] = 0;
    }

    // 完全不包含
    if (l > r || l > query_max || r < query_min) {
        return 0;
    }

    // 全部包含
    if (query_min <= l && r <= query_max) {
        return tree[root];
    }

    int m = (l + r) / 2;
    int lsum = getSum(query_min, query_max, left(root),  l,   m);
    int rsum = getSum(query_min, query_max, right(root), m+1, r);

    return lsum + rsum;
}