线段树的定义

如果我们要求一个数组内任意区间的和,那么很容易想到用前缀和的方式去实现,每次计算任意区间和的时间复杂度都是。但是如果现在能对数组的任意一项进行修改,那么为了保证前缀和仍然有效,最坏情况下必须去更新前缀和数组的每一项,这样修改数据造成的时间复杂度是。所以线段树由此而生,它的目标也是求数组内任意区间的和,但对于数据的修改,它的时间复杂度只需要

引用自百度百科:

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。

简单来说,线段树是一棵二叉搜索树,它的每一个节点存储着一个区间的值,而叶子节点代表一个特殊区间,这个区间只包含一个值。下面是一个例子:

标准实现

回到上面的问题,我们要求一个数组内任意区间的和,并且数组可修改。我们构造一个NumArray类,它的构造函数接收一个数组,其次是实现查询和修改2个方法,提供对数组的修改和对数组任意区间的查询:

1
2
3
4
5
6
class NumArray:
def __init__(self, nums: List[int]):

def update(self, index: int, val: int) -> None:

def sumRange(self, left: int, right: int) -> int:

现在我们来实现一棵线段树,首先构造一个线段树的节点类Nodeleftright分别代表它的左右子节点,lr代表这个节点所指向的区间[l,r]v表示这个区间的和,m为区间的中间值,用于拆分区间:

1
2
3
4
5
6
7
8
class Node:
def __init__(self, l, r,v=0):
self.left = None
self.right = None
self.l = l
self.r = r
self.m = (l + r) >> 1
self.v = v

接下来构造线段树类,同样它提供了修改和查询方法,同时实现了一个__build的私有方法,它用于初始化线段树时,构造出整个线段树节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class SegmentTree:
def __init__(self,nums):
self.nums = nums
self.root = self.__build(0,len(nums) - 1)

def __build(self,l,r):
if l == r:
return Node(l,r,self.nums[l])
node = Node(l,r)
node.left = self.__build(node.l,node.m)
node.right = self.__build(node.m + 1,node.r)
node.v = node.left.v + node.right.v
return node

def __modify(self, l, r, v, node):
if l > node.r or r < node.l:
return
if l <= node.l and node.r <= r:
node.v = v
return
self.__modify(l, r, v, node.left)
self.__modify(l, r, v, node.right)
node.v = node.left.v + node.right.v

def __query(self, l, r,node):
if l > node.r or r < node.l:
return 0
if l <= node.l and node.r <= r:
return node.v
return self.__query(l, r, node.left) + self.__query(l, r, node.right)

def modify(self,index,val):
self.__modify(index,index,val,self.root)

def query(self,left,right):
return self.__query(left,right,self.root)

__modify__query是查询和修改的逻辑,用于内部递归,同时对外提供了modifyquery方法。这里的核心逻辑主要是两个if判断:

  1. if l > node.r or r < node.l:表示查询区间超出节点区间范围,直接返回。对于查询方法来说则返回一个最终不影响结果的值,这里求和所以是0;
  2. if l <= node.l and node.r <= r:表示查询区间包含了完整的节点区间范围(是完全包含,非部分相交),对该节点的值进行处理;
  3. 其他则为相交情况,继续向下递归。

现在对于这个问题,我们只需要操作线段树即可:

1
2
3
4
5
6
7
8
9
class NumArray:
def __init__(self, nums: List[int]):
self.tree=SegmentTree(nums)

def update(self, index: int, val: int) -> None:
self.tree.modify(index,val)

def sumRange(self, left: int, right: int) -> int:
return self.tree.query(left,right)

假设数组长度为10,修改索引1的递归路线如图:

  • 递归查找范围在[1,1]的节点,这里只会查找到叶子节点[1,1]
  • 修改该节点的值,然后递归返回过程中一层层更新父节点的值;

查询区间[4,8],只有被[4,8]完全包含的节点才返回它的值,如果节点不在范围内则返回0,不参与求和计算,否则继续递归,如果能走到叶子节点,这个叶子节点肯定是在范围内的。

其中参与求和计算的为边框为红色的节点[4,4][5,7][8,8]

数组实现

除了真实构建一棵线段树,我们也可以用数组来模拟线段树。假设nums长度为N,则我们需要用4N长度来保存这棵线段树的所有节点。

我们以索引0作为根节点,那么假设某个节点在数组中的索引为i,则它的左节点索引为2*i+1,右节点索引为2*i+2

由于去除了节点Node类,则每次递归时,需要带上节点的索引i,以及当前节点的表示范围[left,right]。对外接口依然不变。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class SegmentTree:
def __init__(self,nums):
self.nums = nums
self.arr = [0] * (4 * len(nums)) # 用数组模拟
self.__build(0,len(self.nums) - 1,0)

def __build(self,l,r,i):
if l == r:
self.arr[i] = self.nums[l]
return
m = (l + r) >> 1
self.__build(l,m,i * 2 + 1)
self.__build(m + 1,r,i * 2 + 2)
self.arr[i] = self.arr[i * 2 + 1] + self.arr[i * 2 + 2]

def __modify(self, l, r, v, left,right,i):
if l > right or r < left:
return
if l <= left and right <= r:
self.arr[i] = v
return
mid = (left + right) >> 1
self.__modify(l, r, v, left,mid,i * 2 + 1)
self.__modify(l, r, v, mid + 1,right,i * 2 + 2)
self.arr[i] = self.arr[i * 2 + 1] + self.arr[i * 2 + 2]

def __query(self, l, r, left,right,i):
if l > right or r < left:
return 0
if l <= left and right <= r:
return self.arr[i]
mid = (left + right) >> 1
return self.__query(l, r, left,mid,i * 2 + 1) + self.__query(l, r, mid + 1,right,i * 2 + 2)

def modify(self,index,val):
self.__modify(index,index,val,0,len(self.nums) - 1,0)

def query(self,left,right):
return self.__query(left,right,0,len(self.nums) - 1,0)

数组大小为什么是4N

由于线段树始终要保存数组中的所有元素值,而这些值都是存储在叶子节点,假设数组长度是N,那么叶子节点总数也是N。对于一棵满二叉树来说,假设它的叶子节点数是N,则它的节点总数为2N-1。然而这是理想情况,叶子节点并不能正好全部分布在同一层,比如上面N=10的情况。

为了保证所有节点都能存储在数组中而不越界,相当于为满二叉树再增加一层,节点总数变为2N-1+2N=4N-1,所以开4N空间是肯定足够的。

线段树的动态开点

无论是构建树或者用数组模拟,都需要一开始就将内存空间分配好,使用动态开点则可以避免这一点,只有当用到了才分配空间,相当于懒加载。

在原有的线段树基础上增加了三个步骤:

  • 节点Node引入一个被称为懒标记的变量add,它用于缓存子节点需修改的值;
  • 增加一个__pushdown方法,用于开辟子节点,并为子节点赋值,这个值就从add传递过来;
  • 增加一个__pushup方法,用于动态开点结束后,修复当前节点的值。

既然是动态开点,那么一开始的__build方法就不再需要了,只需构造一个根节点即可。修改后的线段树如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class Node:
def __init__(self, l, r):
self.left = None
self.right = None
self.l = l
self.r = r
self.m = (l + r) >> 1
self.v = 0
self.add = 0 # 懒标记

class SegmentTree:
def __init__(self,nums):
self.root = Node(0,len(nums)-1) # 根节点表示最大区间

def __modify(self, l, r, v, node):
if l > node.r or r < node.l:
return
if l <= node.l and node.r <= r:
node.v = (node.r - node.l + 1) * v
node.add = v
return
self.__pushdown(node)
self.__modify(l, r, v, node.left)
self.__modify(l, r, v, node.right)
self.__pushup(node)

def __query(self, l, r,node):
if l > node.r or r < node.l:
return 0
if l <= node.l and node.r <= r:
return node.v
self.__pushdown(node)
return self.__query(l, r, node.left) + self.__query(l, r, node.right)

def __pushdown(self,node):
if node.left is None:
node.left = Node(node.l, node.m)
if node.right is None:
node.right = Node(node.m + 1, node.r)
if node.add > 0:
node.left.v = (node.left.r - node.left.l + 1) * node.add
node.right.v = (node.right.r - node.right.l + 1) * node.add
node.left.add = node.add
node.right.add = node.add
node.add = 0

def __pushup(self,node):
node.v = node.left.v + node.right.v

def modify(self,index,val):
self.__modify(index,index,val,self.root)

def query(self,left,right):
return self.__query(left,right,self.root)

由于每个节点的值表示的是区间和,考虑到批量赋值的情况,这里乘以了区间的个数。当然本例不存在这样的情况,本例修改的总是叶子节点(区间大小为1的节点)。

虽然采用了动态开点,但是对于本例来说,初始化时就需要为每个叶子节点赋值,实际上相当于构建了一棵完整的线段树,所以这里采用动态开点并没有什么意义。

1
2
3
4
5
6
7
8
9
10
11
class NumArray:
def __init__(self, nums: List[int]):
self.tree=SegmentTree(nums)
for i,num in enumerate(nums):
self.tree.modify(i,num)

def update(self, index: int, val: int) -> None:
self.tree.modify(index,val)

def sumRange(self, left: int, right: int) -> int:
return self.tree.query(left,right)

不过对于其他涉及到线段树的问题,比如区间范围很庞大,而查找和修改只针对部分区间,动态开点就很有帮助了。

本例出自力扣中等题307. 区域和检索 - 数组可修改,以下这些力扣的问题都可以通过线段树动态开点解决:

题目 难度 方法
307. 区域和检索 - 数组可修改 中等 线段树,树状数组
699. 掉落的方块 困难 线段树-动态开点
715. Range 模块 困难 线段树-动态开点
729. 我的日程安排表 I 中等 线段树-动态开点,二分查找
731. 我的日程安排表 II 中等 线段树-动态开点,差分数组
732. 我的日程安排表 III 困难 线段树-动态开点,差分数组