数据结构之线段树 | 字数总计: 2.6k | 阅读时长: 10分钟 | 阅读量: |
线段树的定义 如果我们要求一个数组内任意区间的和 ,那么很容易想到用前缀和的方式去实现,每次计算任意区间和的时间复杂度都是。但是如果现在能对数组的任意一项进行修改,那么为了保证前缀和仍然有效,最坏情况下必须去更新前缀和数组的每一项,这样修改数据造成的时间复杂度是。所以线段树由此而生,它的目标也是求数组内任意区间的和 ,但对于数据的修改,它的时间复杂度只需要。
引用自百度百科:
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
简单来说,线段树是一棵二叉搜索树,它的每一个节点存储着一个区间的值,而叶子节点代表一个特殊区间,这个区间只包含一个值。下面是一个例子:
标准实现 回到上面的问题,我们要求一个数组内任意区间的和 ,并且数组可修改。我们构造一个NumArray
类,它的构造函数接收一个数组,其次是实现查询和修改2个方法,提供对数组的修改和对数组任意区间的查询:
1 2 3 4 5 6 class NumArray : def __init__ (self, nums: List[int ] ): def update (self, index: int , val: int ) -> None : def sumRange (self, left: int , right: int ) -> int:
现在我们来实现一棵线段树,首先构造一个线段树的节点类Node
。left
和right
分别代表它的左右子节点,l
和r
代表这个节点所指向的区间[l,r]
,v
表示这个区间的和,m
为区间的中间值,用于拆分区间:
1 2 3 4 5 6 7 8 class Node : def __init__ (self, l, r,v=0 ): self.left = None self.right = None self.l = l self.r = r self.m = (l + r) >> 1 self.v = v
接下来构造线段树类,同样它提供了修改和查询方法,同时实现了一个__build
的私有方法,它用于初始化线段树时,构造出整个线段树节点。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 class SegmentTree: def __init__ (self,nums) : self.nums = nums self.root = self.__build(0 ,len(nums) - 1 ) def __build (self,l,r) : if l == r: return Node(l,r,self.nums[l]) node = Node(l,r) node.left = self.__build(node.l,node.m) node.right = self.__build(node.m + 1 ,node.r) node.v = node.left.v + node.right.v return node def __modify (self, l, r, v, node) : if l > node.r or r < node.l: return if l < = node.l and node.r <= r: node.v = v return self.__modify(l, r, v, node.left) self.__modify(l, r, v, node.right) node.v = node.left.v + node.right.v def __query (self, l, r,node) : if l > node.r or r < node.l: return 0 if l < = node.l and node.r <= r: return node.v return self.__query(l, r, node.left) + self.__query(l, r, node.right) def modify (self,index,val) : self.__modify (index,index,val,self.root) def query (self,left,right) : return self.__query (left,right,self.root)
__modify
和__query
是查询和修改的逻辑,用于内部递归,同时对外提供了modify
和query
方法。这里的核心逻辑主要是两个if
判断:
if l > node.r or r < node.l
:表示查询区间超出节点区间范围,直接返回。对于查询方法来说则返回一个最终不影响结果的值,这里求和所以是0;
if l <= node.l and node.r <= r
:表示查询区间包含了完整的节点区间范围(是完全包含,非部分相交),对该节点的值进行处理;
其他则为相交情况,继续向下递归。
现在对于这个问题,我们只需要操作线段树即可:
1 2 3 4 5 6 7 8 9 class NumArray : def __init__ (self, nums: List[int ] ): self.tree=SegmentTree(nums) def update (self, index: int , val: int ) -> None : self.tree.modify(index,val) def sumRange (self, left: int , right: int ) -> int: return self.tree.query(left,right)
假设数组长度为10,修改索引1的递归路线如图:
递归查找范围在[1,1]
的节点,这里只会查找到叶子节点[1,1]
;
修改该节点的值,然后递归返回过程中一层层更新父节点的值;
查询区间[4,8]
,只有被[4,8]
完全包含的节点才返回它的值,如果节点不在范围内则返回0,不参与求和计算,否则继续递归,如果能走到叶子节点,这个叶子节点肯定是在范围内的。
其中参与求和计算的为边框为红色的节点[4,4]
,[5,7]
和[8,8]
。
数组实现 除了真实构建一棵线段树,我们也可以用数组来模拟线段树。假设nums
长度为N,则我们需要用4N长度来保存这棵线段树的所有节点。
我们以索引0作为根节点,那么假设某个节点在数组中的索引为i
,则它的左节点索引为2*i+1
,右节点索引为2*i+2
。
由于去除了节点Node
类,则每次递归时,需要带上节点的索引i
,以及当前节点的表示范围[left,right]
。对外接口依然不变。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 class SegmentTree : def __init__ (self,nums ): self.nums = nums self.arr = [0 ] * (4 * len (nums)) self.__build(0 ,len (self.nums) - 1 ,0 ) def __build (self,l,r,i ): if l == r: self.arr[i] = self.nums[l] return m = (l + r) >> 1 self.__build(l,m,i * 2 + 1 ) self.__build(m + 1 ,r,i * 2 + 2 ) self.arr[i] = self.arr[i * 2 + 1 ] + self.arr[i * 2 + 2 ] def __modify (self, l, r, v, left,right,i ): if l > right or r < left: return if l <= left and right <= r: self.arr[i] = v return mid = (left + right) >> 1 self.__modify(l, r, v, left,mid,i * 2 + 1 ) self.__modify(l, r, v, mid + 1 ,right,i * 2 + 2 ) self.arr[i] = self.arr[i * 2 + 1 ] + self.arr[i * 2 + 2 ] def __query (self, l, r, left,right,i ): if l > right or r < left: return 0 if l <= left and right <= r: return self.arr[i] mid = (left + right) >> 1 return self.__query(l, r, left,mid,i * 2 + 1 ) + self.__query(l, r, mid + 1 ,right,i * 2 + 2 ) def modify (self,index,val ): self.__modify(index,index,val,0 ,len (self.nums) - 1 ,0 ) def query (self,left,right ): return self.__query(left,right,0 ,len (self.nums) - 1 ,0 )
数组大小为什么是4N 由于线段树始终要保存数组中的所有元素值,而这些值都是存储在叶子节点,假设数组长度是N,那么叶子节点总数也是N。对于一棵满二叉树来说,假设它的叶子节点数是N,则它的节点总数为2N-1。然而这是理想情况,叶子节点并不能正好全部分布在同一层,比如上面N=10的情况。
为了保证所有节点都能存储在数组中而不越界,相当于为满二叉树再增加一层,节点总数变为2N-1+2N=4N-1
,所以开4N空间是肯定足够的。
线段树的动态开点 无论是构建树或者用数组模拟,都需要一开始就将内存空间分配好,使用动态开点则可以避免这一点,只有当用到了才分配空间,相当于懒加载。
在原有的线段树基础上增加了三个步骤:
节点Node
引入一个被称为懒标记的变量add
,它用于缓存子节点需修改的值;
增加一个__pushdown
方法,用于开辟子节点,并为子节点赋值,这个值就从add
传递过来;
增加一个__pushup
方法,用于动态开点结束后,修复当前节点的值。
既然是动态开点,那么一开始的__build
方法就不再需要了,只需构造一个根节点即可。修改后的线段树如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 class Node : def __init__ (self, l, r ): self.left = None self.right = None self.l = l self.r = r self.m = (l + r) >> 1 self.v = 0 self.add = 0 class SegmentTree : def __init__ (self,nums ): self.root = Node(0 ,len (nums)-1 ) def __modify (self, l, r, v, node ): if l > node.r or r < node.l: return if l <= node.l and node.r <= r: node.v = (node.r - node.l + 1 ) * v node.add = v return self.__pushdown(node) self.__modify(l, r, v, node.left) self.__modify(l, r, v, node.right) self.__pushup(node) def __query (self, l, r,node ): if l > node.r or r < node.l: return 0 if l <= node.l and node.r <= r: return node.v self.__pushdown(node) return self.__query(l, r, node.left) + self.__query(l, r, node.right) def __pushdown (self,node ): if node.left is None : node.left = Node(node.l, node.m) if node.right is None : node.right = Node(node.m + 1 , node.r) if node.add > 0 : node.left.v = (node.left.r - node.left.l + 1 ) * node.add node.right.v = (node.right.r - node.right.l + 1 ) * node.add node.left.add = node.add node.right.add = node.add node.add = 0 def __pushup (self,node ): node.v = node.left.v + node.right.v def modify (self,index,val ): self.__modify(index,index,val,self.root) def query (self,left,right ): return self.__query(left,right,self.root)
由于每个节点的值表示的是区间和,考虑到批量赋值的情况,这里乘以了区间的个数。当然本例不存在这样的情况,本例修改的总是叶子节点(区间大小为1的节点)。
虽然采用了动态开点,但是对于本例来说,初始化时就需要为每个叶子节点赋值,实际上相当于构建了一棵完整的线段树,所以这里采用动态开点并没有什么意义。
1 2 3 4 5 6 7 8 9 10 11 class NumArray : def __init__ (self, nums: List[int ] ): self.tree=SegmentTree(nums) for i,num in enumerate (nums): self.tree.modify(i,num) def update (self, index: int , val: int ) -> None : self.tree.modify(index,val) def sumRange (self, left: int , right: int ) -> int: return self.tree.query(left,right)
不过对于其他涉及到线段树的问题,比如区间范围很庞大,而查找和修改只针对部分区间,动态开点就很有帮助了。
本例出自力扣中等题307. 区域和检索 - 数组可修改 ,以下这些力扣的问题都可以通过线段树动态开点解决: