leetcode原题:307.区域和检索
不看题解根本不知道还有一个线段树
结构,从没见过这个概念。线段树
是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
-
线段树构造过程
如上图中有10个节点的线段树,根节点为全部区间,然后左节点为1-5,右节点为6-10;之后继续从中间裂变。所以,定义线段树结构:
struct Node { int beg; // 起始区间 int end; // 结束区间 int val; // 区间之和 Node *left; // 左子树 Node *right; // 右子树 Node() { beg = end = val = 0; left = right = nullptr; } };
在依据中间裂变方式定义构造线段树方法:
Node *buildTree(vector<int> &nums, int beg, int end) { if (beg > end) { return nullptr; } Node *p = new Node; p->beg = beg; p->end = end; if (beg == end) { p->val = nums[beg]; } else { p->left = buildTree(nums, beg, (beg + end) / 2); p->right = buildTree(nums, (beg + end) / 2 + 1, end); } return p; }
再构造线段树过程中,叶子节点之和即为该叶子节点对应值,而需要更新非叶子节点中和值。
int calcTree(Node *p) { if (p == nullptr) { return 0; } p->val += calcTree(p->left) + calcTree(p->right); return p->val; }
-
更新区间中某个值
更新区间内某个位置值,则需要遍历线段树,判断更新位置在节点所属区间内,再更新差值。void updateTree(Node *p, int index, int diff) { if (p == nullptr) { return; } if (p->beg <= index && p->end >= index) { p->val += diff; } if (p->beg > index || p->end < index) { return; } updateTree(p->left, index, diff); updateTree(p->right, index, diff); }
-
查找计算区间值
为了计算区间内之和,如区间[left, right]
,则可分别求出[1,right]
-[1, left]
。同样遍历线段树,并判断节点区间。如果节点都在位置之前,则加上整个节点之和。int getValue(Node *p, int index) { if (index < 0) { return 0; } if (index < p->beg) { return getValue(p->left, index); } else if (index > p->end) { return p->val + getValue(p->right, index); } else if (index == p->end) { return p->val; } else { int mid = (p->beg + p->end) / 2; if (index <= mid) { return getValue(p->left, index); } else { return p->left->val + getValue(p->right, index); } } }
-
完整实现如下所示
struct Node { int beg; int end; int val; Node *left; Node *right; Node() { beg = end = val = 0; left = right = nullptr; } }; class NumArray { public: NumArray(vector<int>& nums) { this->nums = nums; root = buildTree(nums, 0, nums.size() - 1); calcTree(root); } ~NumArray() { freeTree(root); } void update(int index, int val) { if (nums[index] == val) { return; } int diff = val - nums[index]; nums[index] = val; updateTree(root, index, diff); } int sumRange(int left, int right) { return getValue(root, right) - getValue(root, left - 1); } private: Node *buildTree(vector<int> &nums, int beg, int end) { if (beg > end) { return nullptr; } Node *p = new Node; p->beg = beg; p->end = end; if (beg == end) { p->val = nums[beg]; } else { p->left = buildTree(nums, beg, (beg + end) / 2); p->right = buildTree(nums, (beg + end) / 2 + 1, end); } return p; } int calcTree(Node *p) { if (p == nullptr) { return 0; } p->val += calcTree(p->left) + calcTree(p->right); return p->val; } void updateTree(Node *p, int index, int diff) { if (p == nullptr) { return; } if (p->beg <= index && p->end >= index) { p->val += diff; } if (p->beg > index || p->end < index) { return; } updateTree(p->left, index, diff); updateTree(p->right, index, diff); } int getValue(Node *p, int index) { if (index < 0) { return 0; } if (index < p->beg) { return getValue(p->left, index); } else if (index > p->end) { return p->val + getValue(p->right, index); } else if (index == p->end) { return p->val; } else { int mid = (p->beg + p->end) / 2; if (index <= mid) { return getValue(p->left, index); } else { return p->left->val + getValue(p->right, index); } } } void freeTree(Node *p) { if (p == nullptr) { return; } freeTree(p->left); freeTree(p->right); delete p; } private: Node *root; vector<int> nums; }; /** * Your NumArray object will be instantiated and called as such: * NumArray* obj = new NumArray(nums); * obj->update(index,val); * int param_2 = obj->sumRange(left,right); */
-
看下leetcode中推荐的题解,小丑还是我自己啊。不是说好的树结构,怎么就用个数组就搞定了。