0
点赞
收藏
分享

微信扫一扫

线段树及经典习题

mm_tang 2022-04-04 阅读 36

线段树(Segment Tree)

 

首先我们知道二叉树,balanced 二叉树可以保证查找的复杂度是logn的复杂度。理解起来十分直观:大概这样子:

之所以是平衡的,是因为它的构建方式,代码上能看出来。

如你所见,底层的就是单个的元素,每往上走一层,就会根据底层的元素做一个范围查询(rangeQuery)操作,这个操作可以是求和,求最小,求最大,同时上一层就会cover all the union of the children's range

there are four operations in total:

  1. build(start,end,vals)->O(n)
  2. update(index,value)->O(logn)
  3. rangeQuery(satrt,end)->O(logn+K) K is the number of the reported segments.  

那我们应该如何实现呢?

其实整体的写法跟快排有异曲同工之妙,也是递归不断的寻找中间点,如果左右满足结束要求,就结束。

当然首先我们要定义它的每一个节点的结构,这里start,end是左右闭区间的。

struct SegmentTreeNode 
{
	int start;
	int end;
	int sumnum;
	SegmentTreeNode* left;
	SegmentTreeNode* right;
	SegmentTreeNode(int s, int e, int n,SegmentTreeNode* l=NULL,SegmentTreeNode* r=NULL)
		:start(s), end(e), sumnum(n), left(l), right(r) {};
};

接下来我们借助递归来构建这棵树,当然如果是项目落地还是用用智能指针吧。实际上复杂度是2n啦,但是忽略掉2那就是n。如果是奇数,那么我们的构建方式保证左边多一个leaf。

SegmentTreeNode* buildTree(int start, int end, vector<int>& vals) 
{
	if (start == end) 
	{
		return new SegmentTreeNode(start, end, vals[start]);
	}
	int mid = start + (end - start) / 2;
	SegmentTreeNode* left = buildTree(start, mid, vals);
	SegmentTreeNode* right = buildTree(mid + 1, end, vals);
	return new SegmentTreeNode(start, end, left->sumnum + right->sumnum, left, right);
}

接下来是更新,更新的操作其实跟构建差不多,我们借助递归的方式,不断地更新当前的node,直到我们找到了start==end==index之后,,我们更新它的值并且返回,在返回的途中,我们不断更新它的父节点的值。忽略掉每一步的更新操作那就是logn。

void updateTree(SegmentTreeNode* root, int index, int val) 
{
	if (root->start == root->end && root->start == index) 
	{
		root->sumnum = val;
		return;
	}
	int mid = root->start + (root->end - root->start) / 2;
	if (index <= mid) 
	{
		updateTree(root->left, index, val);
	}
	else
	{
		updateTree(root->right, index, val);
	}
	root->sumnum = root->left->sumnum + root->right->sumnum;
}

 接下来看看范围查询,挺好理解的, 比较的时候记住什么时候加等号就行。比如val的length是5,那么一半就是2,为了平衡可能是01在左,234在右,根绝构建方式的不同就不同,这里我们的构建方式是012在做,34在右,因此当有边界小于等于2的时候就在左边,而不是小于。

int rangeQuery(SegmentTreeNode* root, int left, int right) 
{
	if (root->start == left && root->end == right) 
	{
		return root->sumnum;
	}
	int mid = root->start + (root->end - root->start) / 2;
	if (right <= mid) //完全落在左边
	{
		return rangeQuery(root->left, left, right);
	}
	else if (left > mid) //完全落在右边
	{
		return rangeQuery(root->right, left, right);
	}
	else //落在中间
	{
		return rangeQuery(root->left, left, mid) + rangeQuery(root->right, mid + 1, right);
	}
}

例题

力扣https://leetcode-cn.com/problems/range-sum-query-mutable/

 加个智能指针稍微改吧改吧:

class NumArray {
    struct SegmentTreeNode 
    {
        int start;
        int end;
        int sumnum;
        unique_ptr<SegmentTreeNode> left;
        unique_ptr<SegmentTreeNode> right;
        SegmentTreeNode(int s, int e, int n,SegmentTreeNode* l=NULL,SegmentTreeNode* r=NULL)
            :start(s), end(e), sumnum(n), left(l), right(r) {};
    };
    SegmentTreeNode* buildTree(int start, int end, vector<int>& vals) 
    {
        if (start == end) 
        {
            return new SegmentTreeNode(start, end, vals[start]);
        }
        int mid = start + (end - start) / 2;
        SegmentTreeNode* left = buildTree(start, mid, vals);
        SegmentTreeNode* right = buildTree(mid + 1, end, vals);
        return new SegmentTreeNode(start, end, left->sumnum + right->sumnum, left, right);
    }
    void updateTree(unique_ptr<SegmentTreeNode>& root, int index, int val) 
    {
        if (root->start == root->end && root->start == index) 
        {
            root->sumnum = val;
            return;
        }
        int mid = root->start + (root->end - root->start) / 2;
        if (index <= mid) 
        {
            updateTree(root->left, index, val);
        }
        else
        {
            updateTree(root->right, index, val);
        }
        root->sumnum = root->left->sumnum + root->right->sumnum;
    }
    int rangeQuery(unique_ptr<SegmentTreeNode>& root, int left, int right) 
    {
        if (root->start == left && root->end == right) 
        {
            return root->sumnum;
        }
        int mid = root->start + (root->end - root->start) / 2;
        if (right <= mid) //完全落在左边
        {
            return rangeQuery(root->left, left, right);
        }
        else if (left > mid) //完全落在右边
        {
            return rangeQuery(root->right, left, right);
        }
        else //落在中间
        {
            return rangeQuery(root->left, left, mid) + rangeQuery(root->right, mid + 1, right);
        }
    }
    unique_ptr<SegmentTreeNode> head;
public:
    NumArray(vector<int>& nums) {
        head=unique_ptr<SegmentTreeNode>(buildTree(0,nums.size()-1,nums));
    }
    
    void update(int index, int val) {
        updateTree(head,index,val);
    }
    
    int sumRange(int left, int right) {
        return rangeQuery(head,left,right);
    }
};

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray* obj = new NumArray(nums);
 * obj->update(index,val);
 * int param_2 = obj->sumRange(left,right);
 */

寄啊,真慢,凑活看…… 

举报

相关推荐

0 条评论