0
点赞
收藏
分享

微信扫一扫

c++实现线段树结构

一条咸鱼的干货 2022-01-07 阅读 94

leetcode原题:307.区域和检索

不看题解根本不知道还有一个线段树结构,从没见过这个概念。线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。


  1. 线段树构造过程
    线段树构造
    如上图中有10个节点的线段树,根节点为全部区间,然后左节点为1-5,右节点为6-10;之后继续从中间裂变。

    所以,定义线段树结构:

    struct Node {
        int beg; // 起始区间
        int end; // 结束区间
        int val; // 区间之和
        Node *left; // 左子树
        Node *right; // 右子树
    
        Node() {
            beg = end = val = 0;
            left = right = nullptr;
        }
    };
    

    在依据中间裂变方式定义构造线段树方法:

        Node *buildTree(vector<int> &nums, int beg, int end) {
            if (beg > end) {
                return nullptr;
            }
            Node *p = new Node;
            p->beg = beg;
            p->end = end;
            if (beg == end) {
                p->val = nums[beg];
            } else {
                p->left = buildTree(nums, beg, (beg + end) / 2);
                p->right = buildTree(nums, (beg + end) / 2 + 1, end);
            }
            return p;
        }
    

    再构造线段树过程中,叶子节点之和即为该叶子节点对应值,而需要更新非叶子节点中和值。

        int calcTree(Node *p) {
            if (p == nullptr) {
                return 0;
            }
            p->val += calcTree(p->left) + calcTree(p->right);
            return p->val;
        }
    
  2. 更新区间中某个值
    更新区间内某个位置值,则需要遍历线段树,判断更新位置在节点所属区间内,再更新差值。

        void updateTree(Node *p, int index, int diff) {
            if (p == nullptr) {
                return;
            }
            if (p->beg <= index && p->end >= index) {
                p->val += diff;
            }
            if (p->beg > index || p->end < index) {
                return;
            }
            updateTree(p->left, index, diff);
            updateTree(p->right, index, diff);
        }
    
  3. 查找计算区间值
    为了计算区间内之和,如区间[left, right],则可分别求出[1,right] - [1, left]。同样遍历线段树,并判断节点区间。如果节点都在位置之前,则加上整个节点之和。

        int getValue(Node *p, int index) {
            if (index < 0) {
                return 0;
            }
            if (index < p->beg) {
                return getValue(p->left, index);
            } else if (index > p->end) {
                return p->val + getValue(p->right, index);
            } else if (index == p->end) {
                return p->val;
            } else {
                int mid = (p->beg + p->end) / 2;
                if (index <= mid) {
                    return getValue(p->left, index);
                } else {
                    return p->left->val + getValue(p->right, index);
                }
            }
        }
    
  4. 完整实现如下所示

    struct Node {
        int beg;
        int end;
        int val;
        Node *left;
        Node *right;
    
        Node() {
            beg = end = val = 0;
            left = right = nullptr;
        }
    };
    
    class NumArray {
    public:
        NumArray(vector<int>& nums) {
            this->nums = nums;
            root = buildTree(nums, 0, nums.size() - 1);
            calcTree(root);
        }
    
    	  ~NumArray() {
    	  	freeTree(root);
    	  }
    
        void update(int index, int val) {
            if (nums[index] == val) {
                return;
            }
            int diff = val - nums[index];
            nums[index] = val;
            updateTree(root, index, diff);
        }
    
        int sumRange(int left, int right) {
            return getValue(root, right) - getValue(root, left - 1);
        }
    
    private:
        Node *buildTree(vector<int> &nums, int beg, int end) {
            if (beg > end) {
                return nullptr;
            }
            Node *p = new Node;
            p->beg = beg;
            p->end = end;
            if (beg == end) {
                p->val = nums[beg];
            } else {
                p->left = buildTree(nums, beg, (beg + end) / 2);
                p->right = buildTree(nums, (beg + end) / 2 + 1, end);
            }
            return p;
        }
    
        int calcTree(Node *p) {
            if (p == nullptr) {
                return 0;
            }
            p->val += calcTree(p->left) + calcTree(p->right);
            return p->val;
        }
    
        void updateTree(Node *p, int index, int diff) {
            if (p == nullptr) {
                return;
            }
            if (p->beg <= index && p->end >= index) {
                p->val += diff;
            }
            if (p->beg > index || p->end < index) {
                return;
            }
            updateTree(p->left, index, diff);
            updateTree(p->right, index, diff);
        }
    
        int getValue(Node *p, int index) {
            if (index < 0) {
                return 0;
            }
            if (index < p->beg) {
                return getValue(p->left, index);
            } else if (index > p->end) {
                return p->val + getValue(p->right, index);
            } else if (index == p->end) {
                return p->val;
            } else {
                int mid = (p->beg + p->end) / 2;
                if (index <= mid) {
                    return getValue(p->left, index);
                } else {
                    return p->left->val + getValue(p->right, index);
                }
            }
        }
    
        void freeTree(Node *p) {
            if (p == nullptr) {
                return;
            }
            freeTree(p->left);
            freeTree(p->right);
            delete p;
        }
    
    private:
        Node *root;
        vector<int> nums;
    };
    
    /**
     * Your NumArray object will be instantiated and called as such:
     * NumArray* obj = new NumArray(nums);
     * obj->update(index,val);
     * int param_2 = obj->sumRange(left,right);
     */
    
    
  5. 看下leetcode中推荐的题解,小丑还是我自己啊。不是说好的树结构,怎么就用个数组就搞定了。

举报

相关推荐

0 条评论