线段树

需求

  • 数组区间内经常发生修改,但是又要频繁求得区间的各类统计信息例如最大值、最小值、区间和等等。
  • 对于一个数组频繁求区间和和修改的情况下:
    • 普通操作:修改:O(1);区间和:O(n)
    • 前缀和:修改:O(n);区间和:O(1)
    • 线段树:修改:O(nlogn);区间和:O(nlogn)

线段树

  • 将数组以二分的形式建成一棵二叉树,叶子节点为数组的值,非叶子节点保存相应的统计信息。这棵树近似完全二叉树的结构,所以可以用数组来构建
  • 下面所列举的代码只需要改变给tree[node]赋值相关代码即可实现求最大值、最小值、区间和,肥肠方便
  • 307. 区域和检索 - 数组可修改

1、数组版

  • 粗略而言,数组版线段树的大小初始化为4N即可
  • 这里_buildTree, _update, _sumRange都传入了参数int start, int end。是通过形参来维护每个树节点的管理范围,事实上,也可以创建一个tree.shape == (3, 4N)的线段树,第一行是统计信息,第二三行是范围起点和终点,代码可见数组版v2
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    class NumArray {
    public:
    NumArray(vector<int>& nums):arr(nums), tree(arr.size()*4) {
    _buildTree(0, 0, arr.size() - 1);
    }

    void update(int index, int val) {
    _update(0, 0, arr.size() - 1, index, val);
    }

    int sumRange(int left, int right) {
    return _sumRange(0, 0, arr.size() - 1, left, right);
    }
    private:
    void _buildTree(int node, int start, int end) {
    if (start == end) {
    tree[node] = arr[start];
    return;
    }
    int mid = (start + end) / 2;
    int node_left = node * 2 + 1;
    int node_right = node * 2 + 2;
    _buildTree(node_left, start, mid);
    _buildTree(node_right, mid + 1, end);
    // 这里可以改为求最大值、最小值
    tree[node] = tree[node_left] + tree[node_right];
    }

    void _update(int node, int start, int end, int idx, int val) {
    if (start == end) {
    arr[idx] = val;
    tree[node] = val;
    return;
    }
    int mid = (start + end) / 2;
    int node_left = node * 2 + 1;
    int node_right = node * 2 + 2;
    if (idx <= mid) _update(node_left, start, mid, idx, val);
    else _update(node_right, mid + 1, end, idx, val);
    // 这里可以改为求最大值、最小值
    tree[node] = tree[node_left] + tree[node_right]; // 修改路径上的值,类似树状数组的寻祖
    }

    int _sumRange(int node, int start, int end, int l, int r) {
    if (l > end || r < start) return 0; // 不在区间内
    if (start >= l && end <= r) return tree[node];
    int mid = (start + end) / 2;
    int node_left = node * 2 + 1;
    int node_right = node * 2 + 2;
    // 这里可以改为求最大值、最小值
    return _sumRange(node_left, start, mid, l, r) + _sumRange(node_right, mid + 1, end, l, r);
    }
    vector<int> arr;
    vector<int> tree;
    };

2、真建树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
struct TNode {
int start;
int end;
int val;
TNode* left;
TNode* right;
};
class NumArray {
public:
NumArray(vector<int>& nums) {
tree = _buildTree(nums, 0, nums.size()-1);
}

void update(int index, int val) {
_update(tree, index, val);
}

int sumRange(int left, int right) {
return _sumRange(tree, left, right);
}
private:
TNode* _buildTree(vector<int>& arr, int l, int r) {
TNode* p = new TNode;
if (l == r) {
p->val = arr[l];
p->left = p->right = nullptr;
}
else {
int m = (l + r) / 2;
p->left = _buildTree(arr, l, m);
p->right = _buildTree(arr, m + 1, r);
p->val = p->left->val + p->right->val;
}
p->start = l; p->end = r;
return p;
}
void _update(TNode* root, int idx, int val) {
int start = root->start, end = root->end;
if (start == end) {
root->val = val;
return;
}
int mid = (start + end) / 2;
if (idx <= mid) _update(root->left, idx, val);
else _update(root->right, idx, val);
root->val = root->left->val + root->right->val;
}
int _sumRange(TNode* root, int l, int r) {
int start = root->start, end = root->end;
if (l > end || r < start) return 0; // 不在区间内
if (start >= l && end <= r) return root->val;
return _sumRange(root->left, l, r) + _sumRange(root->right, l, r);
}
TNode* tree;
};

3、数组版v2

  • 效率相较于v1是低了,占用内存也高了~
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    class NumArray {
    public:
    NumArray(vector<int>& nums):arr(nums), tree(3, vector<int>(arr.size() * 4)) {
    _buildTree(0, 0, arr.size() - 1);
    }

    void update(int index, int val) {
    _update(0, index, val);
    }

    int sumRange(int left, int right) {
    return _sumRange(0, left, right);
    }
    private:
    void _buildTree(int node, int start, int end) {
    if (start == end) {
    tree[0][node] = arr[start];
    tree[1][node] = start; // 赋值
    tree[2][node] = end; // 赋值
    return;
    }
    int mid = (start + end) / 2;
    int node_left = node * 2 + 1;
    int node_right = node * 2 + 2;
    _buildTree(node_left, start, mid);
    _buildTree(node_right, mid + 1, end);
    // 这里可以改为求最大值、最小值
    tree[0][node] = tree[0][node_left] + tree[0][node_right];
    tree[1][node] = start; // 赋值
    tree[2][node] = end; // 赋值
    }

    void _update(int node, int idx, int val) {
    int start = tree[1][node], end = tree[2][node];
    if (start == end) {
    arr[idx] = val;
    tree[0][node] = val;
    return;
    }
    int mid = (start + end) / 2;
    int node_left = node * 2 + 1;
    int node_right = node * 2 + 2;
    if (idx <= mid) _update(node_left, idx, val);
    else _update(node_right, idx, val);
    // 这里可以改为求最大值、最小值
    tree[0][node] = tree[0][node_left] + tree[0][node_right]; // 修改路径上的值,类似树状数组的寻祖
    }

    int _sumRange(int node, int l, int r) {
    int start = tree[1][node], end = tree[2][node];
    if (l > end || r < start) return 0; // 不在区间内
    if (start >= l && end <= r) return tree[0][node];
    int node_left = node * 2 + 1;
    int node_right = node * 2 + 2;
    // 这里可以改为求最大值、最小值
    return _sumRange(node_left, l, r) + _sumRange(node_right, l, r);
    }
    vector<int> arr;
    vector<vector<int>> tree;
    };

参考

  1. 正月点灯笼