线段树
线段树(segment tree)是一种高级数据结构, 专门用于在区间查询和区间更新场景中实现高效数据处理
定义
线段树是一种二叉搜索树, 也是平衡二叉树
它将一个区间划分成一些单元区间, 每个单元区间对应线段树中一个叶结点
(1) 每个节点表示一个区间
(2) 每个非叶子节点均有左右两颗子树, 对应区间左半与右半部分
根节点编号 $1$, 对于节点 $i$, 其左节点编号为 $2i$, 右节点编号为 $2i+1$
(3) 对于任意节点, 表示区间范围为$[x, y]$:
若 $x = y$, 则此为叶子节点
否则令 $mid = \lfloor {\frac{x+y}{2}} \rfloor$, 左儿子对于$[x, mid]$区间, 右儿子对应$[mid+1, y]$区间
- 示例, $n = 10$ 时线段树
节点 $1$, 管理范围为$[1, 10]$, 节点 $2$, 管理范围为$[1, 5]$, 节点 $12$, 管理范围为$[6, 7]$
$\cdots$
graph TB;
A1("1 [1, 10]")
A2("2 [1, 5]")
A3("3 [6, 10]")
A4("4 [1, 3]")
A5("5 [4, 5]")
A6("6 [6, 8]")
A7("7 [9, 10]")
A8("8 [1, 2]")
A9("9 [3, 3]")
A10("10 [4, 4]")
A11("11 [5, 5]")
A12("12 [6, 7]")
A13("13 [8, 8]")
A14("14 [9, 9]")
A15("15 [10, 10]")
A16("16 [1, 1]")
A17("17 [2, 2]")
A24("24 [6, 6]")
A25("25 [7, 7]")
A1-->A2
A2-->A4
A4-->A8
A8-->A16
A8-->A17
A4-->A9
A2-->A5
A5-->A10
A5-->A11
A1-->A3
A3-->A6
A6-->A12
A12-->A24
A12-->A25
A6-->A13
A3-->A7
A7-->A14
A7-->A15
特点
区间信息存储
线段树每个节点都存储一个区间信息, 如区间和、区间最小值或最大值等
平衡性
线段树是平衡二叉树, 因此其高度为$O(log n)$, 其中n是数组长度
保证线段树上操作(如查询和更新)时间复杂度都是$O(log n)$
高效性
线段树能够在$O(log n)$时间复杂度内完成查询和更新操作, 适用于处理静态或动态数组中区间问题
灵活性
线段树不仅支持单点更新, 还可以扩展为区间批量更新(通过懒标记优化)
同时, 线段树还可以处理更复杂区间问题, 如二维线段树用于处理二维平面中区间问题
操作
#include <iostream>
#include <vector>
#include <climits>
template<typename T>
class SegmentTree {
public:
SegmentTree(const vector<T>& arr) {
mSize = arr.size();
// 线段树大小是原数组大小4倍(最坏情况下满二叉树)
mTree.resize(4 * mSize);
build(1, 0, mSize - 1);
}
~SegmentTree() = default;
// 区间查询
T query(int x, int y) {
return query_util(1, 0, mSize - 1, x, y);
}
// 单点更新
void update(int idx, T val) {
// 更新函数也是从1开始, 与build函数保持一致
// 计算差值
int diff = val - arr[idx];
arr[idx] = val; // 更新原数组
// 更新线段树(递归)
update_util(1, 0, n - 1, idx, diff);
}
private:
std::vector<T> mTree;
int mSize;
// 构建线段树(递归)
void build(int node, int start, int end) {
if (start == end) {
// 叶节点, 直接存储数组元素
mTree[node] = arr[start];
return;
}
int mid = (start + end) / 2;
// 递归构建左子树
build(2 * node, start, mid);
// 递归构建右子树
build(2 * node + 1, mid + 1, end);
// 内部节点存储子树和
mTree[node] = mTree[2 * node] + mTree[2 * node + 1];
}
// 查询操作(递归)
int query_util(int node, int start, int end, int x, int y) {
if (y < start || end < x) {
// 查询区间与当前节点区间无交集
return 0;
}
if (x <= start && end <= y) {
// 查询区间完全包含当前节点区间
return mTree[node];
}
// 查询区间与当前节点区间有交集, 但不完全包含
int mid = (start + end) / 2;
int left_sum = query_util(2 * node, start, mid, x, y);
int right_sum = query_util(2 * node + 1, mid + 1, end, x, y);
return left_sum + right_sum;
}
// 单点更新辅助函数
void update_util(int node, int start, int end, int idx, T diff) {
if (start == end) {
// 叶节点, 直接更新
mTree[node] += diff;
return;
}
int mid = (start + end) / 2;
if (idx <= mid) {
update_util(2 * node, start, mid, idx, diff);
} else {
update_util(2 * node + 1, mid + 1, end, idx, diff);
}
}
};
int main() {
std::vector<int> arr = {1, 3, 5, 7, 9, 11};
SegmentTree segTree(arr);
std::cout << "Sum of values in given range [1, 3] = " << segTree.query(1, 3) << std::endl;
segTree.update(1, 10);
std::cout << "Sum of values in given range [1, 3] after update = " << segTree.query(1, 3) << std::endl;
return 0;
}