-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSegmentTree.py
111 lines (93 loc) · 5.45 KB
/
SegmentTree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from collections.abc import Callable
from functools import reduce
ListElemVal = int
class Segment:
def __init__(self, tree, left, right, val=None, left_child=None, right_child=None):
self.tree = tree
self.segment_left_idx: int = left
self.segment_right_idx: int = right
self.val: ListElemVal | None = None
self.left_child: Segment | None = left_child
self.right_child: Segment | None = right_child
def update_one(self, index: int, new_val: ListElemVal) -> None:
self.update_segment(index, index, new_val)
# 到底
def update_segment(self, left: int, right: int, new_val: ListElemVal) -> None:
# left right 可以覆盖并超过本区间的大小,但是必须是个有效的区间,且该区间和本segment有交集
assert (left <= right
and not (right < self.segment_left_idx or self.segment_right_idx < left)), \
"Invalid input left:{} right:{}".format(left, right)
if self.segment_left_idx == self.segment_right_idx:
self.val = new_val
return
if right <= self.left_child.segment_right_idx:
self.left_child.update_segment(left, right, new_val)
elif self.right_child.segment_left_idx <= left:
self.right_child.update_segment(left, right, new_val)
else:
self.left_child.update_segment(left, self.left_child.segment_right_idx, new_val)
self.right_child.update_segment(self.right_child.segment_left_idx, right, new_val)
self.val = reduce(self.tree.list_reducing_func,
map(self.tree.list_element_mapping_func,
[self.left_child.val, self.right_child.val]))
def query_one(self, index: int) -> ListElemVal:
return self.query_segment(index, index)
# 不到底
def query_segment(self, left: int, right: int) -> ListElemVal:
# left right 可以覆盖并超过本区间的大小,但是必须是个有效的区间,且该区间和本segment有交集
assert (left <= right
and not (right < self.segment_left_idx or self.segment_right_idx < left)), \
"Invalid input left:{} right:{}".format(left, right)
if self.segment_left_idx == self.segment_right_idx:
return self.val
if self.segment_left_idx == left and self.segment_right_idx == right:
return self.val
if right <= self.left_child.segment_right_idx:
return self.left_child.query_segment(left, right)
elif self.right_child.segment_left_idx <= left:
return self.right_child.query_segment(left, right)
else:
left_hand_side = self.left_child.query_segment(left, self.left_child.segment_right_idx)
right_hand_side = self.right_child.query_segment(self.right_child.segment_left_idx, right)
return reduce(self.tree.list_reducing_func,
map(self.tree.list_element_mapping_func,
[left_hand_side, right_hand_side]))
class SegmentTree:
def __init__(self,
elements: list[ListElemVal],
list_element_mapping_func: Callable[[ListElemVal], ListElemVal],
list_reducing_func: Callable[[ListElemVal, ListElemVal], ListElemVal]):
self.elements = elements.copy()
self.list_element_mapping_func = list_element_mapping_func
self.list_reducing_func = list_reducing_func
def build_tree(left: int, right: int) -> Segment:
seg = Segment(self, left, right)
if left == right:
seg.val = self.elements[left]
return seg
mid = (left + right) // 2
left_child = build_tree(left, mid)
right_child = build_tree(mid + 1, right)
seg.left_child = left_child
seg.right_child = right_child
seg.val = reduce(self.list_reducing_func,
map(self.list_element_mapping_func,
[seg.left_child.val, seg.right_child.val]))
return seg
self.root: Segment = build_tree(0, len(self.elements)-1)
# # SegmentTree 的意图:
# 一个固定长度的数组, 数组中的不定某个元素和不定某个区间数据发生更新动作,
# 给定任意原数组索引下标值index_left和index_right确定的一个区间
# 想用logN的复杂度查询该区间内的元素统计信息, 比如元素的和,最大元素值,最小元素值。
#
# 如果不借助任何数据结构,时间复杂度是n。
# 使用前缀和和哈希表可以做大常数时间复杂度,但是不允许数组元素发生变化。
# SegmentTree 的是个满二叉树,只有最后一行存在空Node, 所以我们可以使用数组存储,使用数组运算寻找孩子节点。
# 满二叉树意味着 SegmentTree能很方便的持久化
# (使用连个list,一个list存储segment_left_end_index, 一个存segment_right_end_index, 空Node在两个数组中都为nil)
# 源数组长度为n,最坏情况下,当n >= 3时,总的非空Node个数为 4n-5
# (最后一行有连个节点, 到第二行有n-2个叶子节点,1个非叶子节点, 可得最后一行node空间一共有2(n-1), 可得总节点数4n-5)
# 所以使用4n
# 这里的实现不使用数组存储了
# binary index tree 只能解决点查询、点更新问题
# segment tree 可以覆盖前者的所有能力,还能解决区间更新、区间查询的问题