线段树(Segment Tree)
首先我们知道二叉树,balanced 二叉树可以保证查找的复杂度是logn的复杂度。理解起来十分直观:大概这样子:
之所以是平衡的,是因为它的构建方式,代码上能看出来。
如你所见,底层的就是单个的元素,每往上走一层,就会根据底层的元素做一个范围查询(rangeQuery)操作,这个操作可以是求和,求最小,求最大,同时上一层就会cover all the union of the children's range
there are four operations in total:
- build(start,end,vals)->O(n)
- update(index,value)->O(logn)
- rangeQuery(satrt,end)->O(logn+K) K is the number of the reported segments.
那我们应该如何实现呢?
其实整体的写法跟快排有异曲同工之妙,也是递归不断的寻找中间点,如果左右满足结束要求,就结束。
当然首先我们要定义它的每一个节点的结构,这里start,end是左右闭区间的。
struct SegmentTreeNode
{
int start;
int end;
int sumnum;
SegmentTreeNode* left;
SegmentTreeNode* right;
SegmentTreeNode(int s, int e, int n,SegmentTreeNode* l=NULL,SegmentTreeNode* r=NULL)
:start(s), end(e), sumnum(n), left(l), right(r) {};
};
接下来我们借助递归来构建这棵树,当然如果是项目落地还是用用智能指针吧。实际上复杂度是2n啦,但是忽略掉2那就是n。如果是奇数,那么我们的构建方式保证左边多一个leaf。
SegmentTreeNode* buildTree(int start, int end, vector<int>& vals)
{
if (start == end)
{
return new SegmentTreeNode(start, end, vals[start]);
}
int mid = start + (end - start) / 2;
SegmentTreeNode* left = buildTree(start, mid, vals);
SegmentTreeNode* right = buildTree(mid + 1, end, vals);
return new SegmentTreeNode(start, end, left->sumnum + right->sumnum, left, right);
}
接下来是更新,更新的操作其实跟构建差不多,我们借助递归的方式,不断地更新当前的node,直到我们找到了start==end==index之后,,我们更新它的值并且返回,在返回的途中,我们不断更新它的父节点的值。忽略掉每一步的更新操作那就是logn。
void updateTree(SegmentTreeNode* root, int index, int val)
{
if (root->start == root->end && root->start == index)
{
root->sumnum = val;
return;
}
int mid = root->start + (root->end - root->start) / 2;
if (index <= mid)
{
updateTree(root->left, index, val);
}
else
{
updateTree(root->right, index, val);
}
root->sumnum = root->left->sumnum + root->right->sumnum;
}
接下来看看范围查询,挺好理解的, 比较的时候记住什么时候加等号就行。比如val的length是5,那么一半就是2,为了平衡可能是01在左,234在右,根绝构建方式的不同就不同,这里我们的构建方式是012在做,34在右,因此当有边界小于等于2的时候就在左边,而不是小于。
int rangeQuery(SegmentTreeNode* root, int left, int right)
{
if (root->start == left && root->end == right)
{
return root->sumnum;
}
int mid = root->start + (root->end - root->start) / 2;
if (right <= mid) //完全落在左边
{
return rangeQuery(root->left, left, right);
}
else if (left > mid) //完全落在右边
{
return rangeQuery(root->right, left, right);
}
else //落在中间
{
return rangeQuery(root->left, left, mid) + rangeQuery(root->right, mid + 1, right);
}
}
例题
力扣https://leetcode-cn.com/problems/range-sum-query-mutable/
加个智能指针稍微改吧改吧:
class NumArray {
struct SegmentTreeNode
{
int start;
int end;
int sumnum;
unique_ptr<SegmentTreeNode> left;
unique_ptr<SegmentTreeNode> right;
SegmentTreeNode(int s, int e, int n,SegmentTreeNode* l=NULL,SegmentTreeNode* r=NULL)
:start(s), end(e), sumnum(n), left(l), right(r) {};
};
SegmentTreeNode* buildTree(int start, int end, vector<int>& vals)
{
if (start == end)
{
return new SegmentTreeNode(start, end, vals[start]);
}
int mid = start + (end - start) / 2;
SegmentTreeNode* left = buildTree(start, mid, vals);
SegmentTreeNode* right = buildTree(mid + 1, end, vals);
return new SegmentTreeNode(start, end, left->sumnum + right->sumnum, left, right);
}
void updateTree(unique_ptr<SegmentTreeNode>& root, int index, int val)
{
if (root->start == root->end && root->start == index)
{
root->sumnum = val;
return;
}
int mid = root->start + (root->end - root->start) / 2;
if (index <= mid)
{
updateTree(root->left, index, val);
}
else
{
updateTree(root->right, index, val);
}
root->sumnum = root->left->sumnum + root->right->sumnum;
}
int rangeQuery(unique_ptr<SegmentTreeNode>& root, int left, int right)
{
if (root->start == left && root->end == right)
{
return root->sumnum;
}
int mid = root->start + (root->end - root->start) / 2;
if (right <= mid) //完全落在左边
{
return rangeQuery(root->left, left, right);
}
else if (left > mid) //完全落在右边
{
return rangeQuery(root->right, left, right);
}
else //落在中间
{
return rangeQuery(root->left, left, mid) + rangeQuery(root->right, mid + 1, right);
}
}
unique_ptr<SegmentTreeNode> head;
public:
NumArray(vector<int>& nums) {
head=unique_ptr<SegmentTreeNode>(buildTree(0,nums.size()-1,nums));
}
void update(int index, int val) {
updateTree(head,index,val);
}
int sumRange(int left, int right) {
return rangeQuery(head,left,right);
}
};
/**
* Your NumArray object will be instantiated and called as such:
* NumArray* obj = new NumArray(nums);
* obj->update(index,val);
* int param_2 = obj->sumRange(left,right);
*/
寄啊,真慢,凑活看……