0
点赞
收藏
分享

微信扫一扫

线段树(Segment tree)


线段树

线段树是一棵完满(Full)二叉树,储存区间 [线段],快速查询和更新区间值。

视频讲解为什么要开 4n 空间

线段树(Segment tree)_结点


线段树主要实现两个方法:「区间更新」&「区间查询」,时间复杂度均为 O(logn)。

线段树(Segment tree)_算法_02

建树

  1. 使用数组表示线段树,结点编号从 1 开始,结点 i,左结点为 2 * i ,右结点为 2 * i + 1。
  2. Node 类存储结点建树。

两种方式: 携带结点信息和通过传参传递信息。

public void build(Node node, int start, int end) {
    if (start == end) { // 到达叶子节点
        node.val = arr[end];
        return;
    }
    int mid = start + end >> 1;
    build(node.left, start, mid);
    build(node.right, mid + 1, end);
    pushUp(node);
}

// 向上更新 merge
private void pushUp(Node node) {
    node.val = node.left.val + node.right.val;
}

class Node {
    Node left, right; // 左右子结点    
    int val; // 当前结点值可能是和或最值
    // int start, end; // 一般通过传参获得
}

更新

「区间更新」,「点更新」是区间更新的特例。

「懒标记」

区间更新后,给对应的结点加懒标记,表示子结点有待更新。

class Node {    
    Node left, right;  
    int val;    
    int add; // 懒标记
}

「动态开点」,下推懒标记时,如果不存在左右子结点,那么先创建左右子结点。

// leftNum 和 rightNum 表示左右子区间的叶子结点数量
// 「加减」更新,懒标记 × 叶子结点的数量。
private void pushDown(Node node, int leftNum, int rightNum) {
    // 动态开点
    if (node.left == null) node.left = new Node();
    if (node.right == null) node.right = new Node();
    // 0,表示没有标记
    if (node.add == 0) return;
    // 当前结点加上标记值 × 该子树所有叶子结点的数量
    node.left.val += node.add * leftNum;
    node.right.val += node.add * rightNum;
    // 把标记下推给子结点,「加减」更新,下推懒标记时需要累加。
    node.left.add += node.add;
    node.right.add += node.add;
    // 取消当前结点标记
    node.add = 0;
}

// 在区间 [start, end] 中更新区间 [l, r] 的值,将区间 [l, r] + val
public void update(Node node, int start, int end, int l, int r, int val) {
    // 1、区间覆盖,二分 [start, end] 2、区间相等,二分 [l, r]。
    if (l <= start && end <= r) { 
        node.val += (end - start + 1) * val; // 区间包含的每个叶子结点 + val       
        node.add += val;  // 添加懒标记
        return;
    }
    int mid = start + end >> 1;
    // 下推标记
    // mid - start + 1:左子树叶子结点数,end - mid:右子树叶子结点数
    pushDown(node, mid - start + 1, end - mid);
    // [start, mid] 和 [l, r] 可能有交集
    if (l <= mid) update(node.left, start, mid, l, r, val);
    // [mid + 1, end] 和 [l, r] 可能有交集
    if (r > mid) update(node.right, mid + 1, end, l, r, val);
    // 向上更新
    pushUp(node);
}

查询

// 在区间 [start, end] 中查询区间 [l, r] 的结果,即 [l ,r] 保持不变
public int query(Node node, int start, int end, int l, int r) {
    // 区间 [l, r] 覆盖区间 [start, end] 二分区间 [start, end]
    if (l <= start && end <= r) return node.val;    
    int mid = start + end >> 1, ans = 0;
    // 下推标记
    pushDown(node, mid - start + 1, end - mid);
    // [start, mid] 和 [l, r] 可能有交集
    if (l <= mid) ans += query(node.left, start, mid, l, r);
    // [mid + 1, end] 和 [l, r] 可能有交集
    if (r > mid) ans += query(node.right, mid + 1, end, l, r);
    return ans;
}

动态开点

动态开点的优势在于,不需要事前构造空树,而是在插入操作 update 和查询操作 query 时根据访问需要进行「开点」操作。由于不保证查询和插入都是连续的,因此对于父结点 u 而言,不能通过 u << 1 和 u << 1 | 1 的固定方式进行访问,而要将结点 tr[u] 的左右结点所在 tr 数组的下标进行存储,分别记为 ls 和 rs 属性。对于 tr[u].left = 0 和 tr[u].right = 0 则是代表子结点尚未被创建,当需要访问到它们,则将其进行创建。

线段树的插入和查询都是 logn 的(如果涉及区间修改,则由懒标记来确保 logn 复杂度),因此在单次操作的时候,最多会创建数量级为 logn 的点,因此空间复杂度为 O(mlogn),而不是 O(4 * n),而开点数的预估需不能仅仅根据 logn 来进行,还要对具体常数进行分析,才能得到准确的点数上界。

动态开点是按需创建区间,如果是按照连续段进行查询或插入,最坏情况下仍然会占到 4 * n 的空间,因此盲猜 logn 的常数在 4 左右,保守一点可以直接估算到 6,因此可以估算点数为 6∗m∗logn,其中 n = 1e9 和 m = 1e4 分别代表值域大小和查询次数。

当然一个比较实用的估点方式可以「尽可能的多开点数」,利用题目给定的空间上界和创建的自定义类(结构体)的大小,尽可能的多开( Java 的 128M 可以开到 5 * 106 以上)。

线段树模版
注意:基于求「区间和」以及对区间进行「加减」的更新操作,且为「动态开点」。

// 线段树(动态开点)
public class SegmentTreeDynamic {
    class Node {
        Node left, right;
        int val, add;
    }
    private int N = (int) 1e9;
    private Node root = new Node();
    public void update(Node node, int start, int end, int l, int r, int val) {
        if (l <= start && end <= r) {
            node.val += (end - start + 1) * val;
            node.add += val;
            return ;
        }
        int mid = start + end >> 1;
        pushDown(node, mid - start + 1, end - mid);
        if (l <= mid) update(node.left, start, mid, l, r, val);
        if (r > mid) update(node.right, mid + 1, end, l, r, val);
        pushUp(node);
    }
    public int query(Node node, int start, int end, int l, int r) {
        if (l <= start && end <= r) return node.val;
        int mid = start + end >> 1, ans = 0;
        pushDown(node, mid - start + 1, end - mid);
        if (l <= mid) ans += query(node.left, start, mid, l, r);
        if (r > mid) ans += query(node.right, mid + 1, end, l, r);
        return ans;
    }
    private void pushUp(Node node) {
        node.val = node.left.val + node.right.val;
    }
    private void pushDown(Node node, int leftNum, int rightNum) {
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
        if (node.add == 0) return ;
        node.left.val += node.add * leftNum;
        node.right.val += node.add * rightNum;
        node.left.add += node.add;
        node.right.add += node.add;
        node.add = 0;
    }
}

「线段树」应用范围:

数组不变,求区间和:「前缀和」、「树状数组」、「线段树」
多次修改某个数(单点),求区间和:「树状数组」、「线段树」
多次修改某个区间,输出最终结果:「差分」
多次修改某个区间,求区间和:「线段树」、「树状数组」(看修改区间范围大小)
多次将某个区间变成同一个数,求区间和:「线段树」、「树状数组」(看修改区间范围大小)

应该按这样的优先级进行考虑:

简单求区间和,用「前缀和」
多次将某个区间变成同一个数,用「线段树」
其他情况,用「树状数组」

303. 区域和检索 - 数组不可变

class NumArray {
    int[] prefixSum;
    public NumArray(int[] nums) {
        int n = nums.length;
        prefixSum = new int[n + 1];
        for(int i = 0; i < n; i++){
            prefixSum[i + 1] = prefixSum[i] + nums[i];
        }        
    }
    
    public int sumRange(int left, int right) {
        return prefixSum[right + 1] - prefixSum[left];
    }
}

307. 区域和检索 - 数组可修改

单点更新,区间求和。

class NumArray {
    int[] arr, tree;

    public NumArray(int[] nums) {
        int n = nums.length;
        arr = nums;
        tree = new int[n * 4];
        build(1, 0, n - 1); // 从 1 开始
    }

    public void update(int index, int val) {
        update(1, 0, arr.length - 1, index, val);
    }

    // 建树 结点(区间 [start, end])对应下标,通过传参获得
    void build(int node, int start, int end) {
        // 叶子结点的值为对应数组的元素
        if (start == end) tree[node] = arr[end];
        else {
            int mid = start + end >> 1; // 二分区间递归建树
            int leftNode = node << 1; // 左子树
            int rightNode = node << 1 | 1;
            build(leftNode, start, mid);
            build(rightNode, mid + 1, end);
            pushUp(node, leftNode, rightNode); // 回溯上推区间和
        }
    }

    private void pushUp(int node, int leftNode, int rightNode) {
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    // 单点更新  index 是数组下标,对应 叶子结点 区间的边界,node 是 tree 编号。
    void update(int node, int start, int end, int index, int val) {
        if (start == end) { // start = end = index
            tree[node] = val; // 更新叶子结点,即区间和。
            // arr[index] = val; // 更新数组元素
        } else {
            int mid = start + end >> 1;
            int leftNode = node << 1;
            int rightNode = node << 1 | 1;
            if (index <= mid) update(leftNode, start, mid, index, val);
            else update(rightNode, mid + 1, end, index, val);
            pushUp(node, leftNode, rightNode);
        }
    }

    public int sumRange(int left, int right) {
        return query(1, 0, arr.length - 1, left, right);
    }

    // 区间查询
    int query(int node, int start, int end, int L, int R) {
        if (R < start || end < L) return 0;
        if (L <= start && end <= R) return tree[node]; // 覆盖区间
        int mid = start + end >> 1;
        int leftNode = node << 1;
        int rightNode = node << 1 | 1;
        int sumLeft = query(leftNode, start, mid, L, R);
        int sumRight = query(rightNode, mid + 1, end, L, R);
        return sumLeft + sumRight;
    }
}

Node 类实现树

class NumArray {
    int[] arr;
    Node root = new Node();

    public NumArray(int[] nums) {
        int n = nums.length;
        arr = nums;
        build(root, 0, n - 1);
    }

    public void update(int index, int val) {
        update(root, 0, arr.length - 1, index, val);
    }

    public int sumRange(int left, int right) {
        return query(root, 0, arr.length - 1, left, right);
    }

    // 建树
    void build(Node node, int start, int end) {
        if (start == end) {
            node.val = arr[end];
            return;
        }
        int mid = start + end >> 1;
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
        build(node.left, start, mid);
        build(node.right, mid + 1, end);
        node.val = node.left.val + node.right.val;
    }

    // 单点更新 start, end, index 指下标
    void update(Node node, int start, int end, int index, int val) {
        if (start == end) {
            node.val = val;
            return;
        }
        int mid = start + end >> 1;
        if (index <= mid) update(node.left, start, mid, index, val);
        else update(node.right, mid + 1, end, index, val);
        node.val = node.left.val + node.right.val;
    }

    // 区间查询
    int query(Node node, int start, int end, int L, int R) {
        if (R < start || end < L) return 0;
        if (L <= start && end <= R) return node.val;
        int mid = start + end >> 1;
        int sumLeft = query(node.left, start, mid, L, R);
        int sumRight = query(node.right, mid + 1, end, L, R);
        return sumLeft + sumRight;
    }
}

class Node {
    Node left, right;
    int val;
}

933. 最近的请求次数

class RecentCounter {
    Deque<Integer> q;
    public RecentCounter() {
        q = new ArrayDeque();
    }
    
    public int ping(int t) {
        q.offer(t);
        while(q.peek() < t - 3000) q.poll();
        return q.size();
    }
}

存储所有请求,查询最近的请求次数,涉及 「单点修改」「区间查询」
t 的数据范围为 1e9,调用次数是 1e4, 「强制在线」 无法 「离散化」 解决 MLE 问题,使用 「线段树(动态开点)」MLE(Memory Limit Exceeded) 超出空间限制。

class RecentCounter {
    class Node { // 携带区间信息
        // left 和 right 分别代表当前结点(区间)的左右子结点在 tree 的下标
        // val 代表在当前结点(区间)所包含的数的个数
        int left, right, val;
    }

    int N = (int) 1e9, M = 800010, idx = 1;
    Node[] tree = new Node[M]; // 用数组动态开点建树

    void update(int u, int start, int end, int x, int v) {
        if (start == x && end == x) {
            tree[u].val += v;
            return;
        }
        lazyCreate(u);
        int mid = start + end >> 1;
        if (x <= mid) update(tree[u].left, start, mid, x, v); // 覆盖
        else update(tree[u].right, mid + 1, end, x, v);
        pushup(u);
    }

    int query(int u, int start, int end, int l, int r) {
        if (l <= start && end <= r) return tree[u].val;
        lazyCreate(u);
        int mid = start + end >> 1, ans = 0;
        if (l <= mid) ans = query(tree[u].left, start, mid, l, r);
        if (r > mid) ans += query(tree[u].right, mid + 1, end, l, r);
        return ans;
    }

    void lazyCreate(int u) {
        if (tree[u] == null) tree[u] = new Node();
        if (tree[u].left == 0) {
            tree[u].left = ++idx;
            tree[tree[u].left] = new Node();
        }
        if (tree[u].right == 0) {
            tree[u].right = ++idx;
            tree[tree[u].right] = new Node();
        }
    }

    void pushup(int u) {
        tree[u].val = tree[tree[u].left].val + tree[tree[u].right].val;
    }

    public int ping(int t) {
        update(1, 1, N, t, 1);
        return query(1, 1, N, Math.max(0, t - 3000), t);
    }
}

时间复杂度:令 ping 的调用次数为 m,值域大小为 n,线段树的插入和查询复杂度均为 O(logn)
空间复杂度:O(m∗logn)

class RecentCounter {

    public RecentCounter() {
    }
    
    public int ping(int t) {
        update(root, 1, N, t, t, 1); // 区间(单点)更新
        return query(root, 1, N, Math.max(0, t - 3000), t); // 区间查询
    }

    int N = (int) 1e9;
    Node root = new Node();

    void update(Node node, int start, int end, int l, int r, int val) {
        if (l <= start && end <= r) {
            node.val += val;
            node.add += val;
            return;
        }
        int mid = start + end >> 1;
        pushDown(node, mid - start + 1, end - mid);
        if (l <= mid) update(node.left, start, mid, l, r, val);
        if (r > mid) update(node.right, mid + 1, end, l, r, val);
        pushUp(node);
    }

    public int query(Node node, int start, int end, int l, int r) {
        if (l <= start && end <= r) return node.val;
        int mid = start + end >> 1, ans = 0;
        pushDown(node, mid - start + 1, end - mid);
        if (l <= mid) ans = query(node.left, start, mid, l, r);
        if (r > mid) ans += query(node.right, mid + 1, end, l, r);
        return ans;
    }

    private void pushUp(Node node) {
        node.val = node.left.val + node.right.val;
    }

    private void pushDown(Node node, int leftNum, int rightNum) {
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
        if (node.add == 0) return;
        node.left.val += node.add * leftNum;
        node.right.val += node.add * rightNum;
        node.left.add += node.add;
        node.right.add += node.add;
        node.add = 0;
    }

    class Node {
        Node left, right; // 左右子结点        
        int val, add; // 当前结点值,以及懒标记
    }
}

699. 掉落的方块

线段树(动态开点)的两种方式

每次从插入操作都附带一次询问,因此询问次数为 1e3,左端点的最大值为 1e8,边长最大值为 1e6,由此可知值域范围大于 1e8,但不超过 1e9。

对于值域范围大,但查询次数有限的区间和,一般要么使用 「离散化 + 线段树」,要么使用 「线段树(动态开点)」 进行求解。

本题为「非强制在线」问题,因此可以先对 ps 数组进行离散化,将值域映射到较小的空间,然后套用固定占用 4×n 空间的线段树求解。但更为灵活(能够同时应对强制在线问题)的求解方式是「线段树(动态开点)」。

将顺序放置方块的操作(假设当前方块的左端点为 a,边长为 len,则有右端点为 b = a + len),分成如下两步进行:

查询当前范围 [a, b] 的最大高度 cur;
更新当前范围 [a, b] 的最新高度为 cur + len。
「区间修改 + 区间查询」,需要实现带「懒标记」,确保在进行「区间修改」时复杂度仍为 O(logn)。

另外有一个需要注意的细节是:不同方块之间的边缘可以重合,但不会导致方块叠加,因此当对一个区间 [a, b] 进行操作(查询或插入)时,可以将其调整为 [a, b - 1],从而解决边缘叠加操作高度错误的问题。

class Solution {
    public List<Integer> fallingSquares(int[][] positions) {
        List<Integer> ans = new ArrayList<>();
        for (int[] position : positions) {
            int x = position[0], h = position[1];
            // 先查询出 [x, x + h] 的值
            int cur = query(root, 0, N, x, x + h - 1);
            // 更新 [x, x + h - 1] 为 cur + h 
            update(root, 0, N, x, x + h - 1, cur + h);
            ans.add(root.val);
        }
        return ans;
    }

    int N = (int) 1e9;
    Node root = new Node();

    // 区间更新 加懒标记
    void update(Node node, int start, int end, int l, int r, int val) {
        if (l <= start && end <= r) {
            node.val = val;
            node.add = val;
            return;
        }
        pushDown(node);
        int mid = start + end >> 1;
        if (l <= mid) update(node.left, start, mid, l, r, val);
        if (r > mid) update(node.right, mid + 1, end, l, r, val);
        pushUp(node);
    }

    // 区间查询
    public int query(Node node, int start, int end, int l, int r) {
        if (l <= start && end <= r) return node.val;
        pushDown(node);
        int mid = start + end >> 1, ans = 0;
        if (l <= mid) ans = query(node.left, start, mid, l, r);
        if (r > mid) ans = Math.max(ans, query(node.right, mid + 1, end, l, r));
        return ans;
    }

    private void pushUp(Node node) {
        // 每个结点存的是当前区间的最大值
        node.val = Math.max(node.left.val, node.right.val);
    }

    // 动态开点,下推懒标记
    private void pushDown(Node node) {
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
        if (node.add == 0) return;
        node.left.val = node.add;
        node.right.val = node.add;
        node.left.add = node.add;
        node.right.add = node.add;
        node.add = 0;
    }

    class Node {
        Node left, right;
        int val, add;
    }
}

线段树(动态开点 - 估点)
估点的基本方式 🧀 求解常见「值域爆炸,查询有限」区间问题的几种方式。

简单来说,可以直接估算为 6×m×logn 即可,其中 m 为询问次数(对应本题就是 ps 的长度),而 n 为值域大小(对应本题可直接取成 1e9);而另外一个比较实用(避免估算)的估点方式可以「尽可能的多开点数」,利用题目给定的空间上界和创建的自定义类(结构体)的大小,尽可能的多开(不考虑字节对齐,或者结构体过大的情况,Java 的 128M128M 可以开到 5×106 以上)。

class Solution {
    class Node {
        // ls 和 rs 分别代表当前区间的左右子节点所在 tr 数组中的下标
        // val 代表当前区间的最大高度,add 为懒标记
        int ls, rs, val, add;
    }
    int N = (int)1e9, cnt = 0;
    Node[] tr = new Node[1000010];
    void update(int u, int lc, int rc, int l, int r, int v) {
        if (l <= lc && rc <= r) {
            tr[u].val = v;
            tr[u].add = v;
            return ;
        }
        pushdown(u);
        int mid = lc + rc >> 1;
        if (l <= mid) update(tr[u].ls, lc, mid, l, r, v);
        if (r > mid) update(tr[u].rs, mid + 1, rc, l, r, v);
        pushup(u);
    }
    int query(int u, int lc, int rc, int l, int r) {
        if (l <= lc && rc <= r) return tr[u].val;
        pushdown(u);
        int mid = lc + rc >> 1, ans = 0;
        if (l <= mid) ans = query(tr[u].ls, lc, mid, l, r);
        if (r > mid) ans = Math.max(ans, query(tr[u].rs, mid + 1, rc, l, r));
        return ans;
    }
    void pushdown(int u) {
        if (tr[u] == null) tr[u] = new Node();
        if (tr[u].ls == 0) {
            tr[u].ls = ++cnt;
            tr[tr[u].ls] = new Node();
        }
        if (tr[u].rs == 0) {
            tr[u].rs = ++cnt;
            tr[tr[u].rs] = new Node();
        }
        if (tr[u].add == 0) return ;
        int add = tr[u].add;
        tr[tr[u].ls].add = add; tr[tr[u].rs].add = add;
        tr[tr[u].ls].val = add; tr[tr[u].rs].val = add;
        tr[u].add = 0;
    }
    void pushup(int u) {
        tr[u].val = Math.max(tr[tr[u].ls].val, tr[tr[u].rs].val);
    }
    public List<Integer> fallingSquares(int[][] ps) {
        List<Integer> ans = new ArrayList<>();
        tr[1] = new Node();
        for (int[] info : ps) {
            int x = info[0], h = info[1], cur = query(1, 1, N, x, x + h - 1);
            update(1, 1, N, x, x + h - 1, cur + h);
            ans.add(tr[1].val);
        }
        return ans;
    }
}

时间复杂度:令 m 为查询次数,n 为值域大小,复杂度为 O(mlogn)
空间复杂度:O(mlogn)

729. 我的日程安排表 I

class MyCalendar {

    public MyCalendar() {
    }
    
    public boolean book(int start, int end) {
        // 先查询该区间是否为 0
        if (query(root, 0, N, start, end - 1) != 0) return false;
        // 更新该区间
        update(root, 0, N, start, end - 1, 1);
        return true;
    }

    class Node {      
        Node left, right;        
        int val, add;
    }
    private int N = (int) 1e9;
    private Node root = new Node();
    public void update(Node node, int start, int end, int l, int r, int val) {
        if (l <= start && end <= r) {
            node.val += val;
            node.add += val;
            return ;
        }
        pushDown(node);
        int mid = (start + end) >> 1;
        if (l <= mid) update(node.left, start, mid, l, r, val);
        if (r > mid) update(node.right, mid + 1, end, l, r, val);
        pushUp(node);
    }
    public int query(Node node, int start, int end, int l, int r) {
        if (l <= start && end <= r) return node.val;
        pushDown(node);
        int mid = (start + end) >> 1, ans = 0;
        if (l <= mid) ans = query(node.left, start, mid, l, r);
        if (r > mid) ans = Math.max(ans, query(node.right, mid + 1, end, l, r));
        return ans;
    }
    private void pushUp(Node node) {
        // 每个节点存的是当前区间的最大值
        node.val = Math.max(node.left.val, node.right.val);
    }
    private void pushDown(Node node) {
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
        if (node.add == 0) return ;
        node.left.val += node.add;
        node.right.val += node.add;
        node.left.add += node.add;
        node.right.add += node.add;
        node.add = 0;
    }
}

731. 我的日程安排表 II

class MyCalendarTwo {

    public MyCalendarTwo() {
    }
    
    public boolean book(int start, int end) {
        if (query(root, 0, N, start, end - 1) == 2) return false;
        update(root, 0, N, start, end - 1, 1);
        return true;
    }
}

732. 我的日程安排表 III

class MyCalendarThree {

    public MyCalendarThree() {
    }
    
    public int book(int start, int end) {
        // 只用到了 update
        update(root, 0, N, start, end - 1, 1);
        // 最大值即为根节点的值
        return root.val;
    }
}

715. Range 模块

class RangeModule {

    public RangeModule() {
    }
    
    public void addRange(int left, int right) {
        // 1 表示覆盖;-1 表示取消覆盖
        update(root, 1, N, left, right - 1, 1);
    }
    
    public boolean queryRange(int left, int right) {
        return query(root, 1, N, left, right - 1);
    }
    
    public void removeRange(int left, int right) {
        // 1 表示覆盖;-1 表示取消覆盖
        update(root, 1, N, left, right - 1, -1);
    }

    private int N = (int) 1e9;
    private Node root = new Node();
    public void update(Node node, int start, int end, int l, int r, int val) {
        if (l <= start && end <= r) {
            // 1 表示覆盖;-1 表示取消覆盖
            node.cover = val == 1;
            node.add = val;
            return ;
        }
        int mid = (start + end) >> 1;
        pushDown(node);
        if (l <= mid) update(node.left, start, mid, l, r, val);
        if (r > mid) update(node.right, mid + 1, end, l, r, val);
        pushUp(node);
    }
    public boolean query(Node node, int start, int end, int l, int r) {
        if (l <= start && end <= r) return node.cover;
        int mid = (start + end) >> 1;
        pushDown(node);
        // 查询左右子树是否被覆盖
        boolean ans = true;
        if (l <= mid) ans = ans && query(node.left, start, mid, l, r);
        if (r > mid) ans = ans && query(node.right, mid + 1, end, l, r);
        return ans;
    }
    private void pushUp(Node node) {
        node.cover = node.left.cover && node.right.cover;
    }
    private void pushDown(Node node) {
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
        if (node.add == 0) return ;
        node.left.cover = node.add == 1;
        node.right.cover = node.add == 1;
        node.left.add = node.add;
        node.right.add = node.add;
        node.add = 0;
    }
}

class Node {
    Node left, right;
    // 表示当前区间是否被覆盖
    boolean cover;
    int add;
}

2407. 最长递增子序列 II

class Solution {
    public int lengthOfLIS(int[] nums, int k) {
        // 单点更新,区间查询
        int ans = 0;
        for (int x:nums) {
            // 查询区间 [x - k, x - 1] 的最大值
            int cnt = query(root, 0, N, Math.max(0, x - k), x - 1);
            update(root, 0, N, x, ++cnt);
            ans = Math.max(ans, cnt);
        }
        return ans;
    }

    int N = (int) 1e5;
    Node root = new Node();
    void update(Node node, int start, int end, int x, int val) {
        if (start == end) {
            node.val = val;
            return ;
        }
        pushDown(node);
        int mid = (start + end) >> 1;
        if (x <= mid) update(node.left, start, mid, x, val);
        else update(node.right, mid + 1, end, x, val);
        pushUp(node);
    }
    
    int query(Node node, int start, int end, int l, int r) {
        if (l <= start && end <= r) return node.val;
        pushDown(node);
        int mid = start + end >> 1, ans = 0;
        if (l <= mid) ans = query(node.left, start, mid, l, r);
        if (r > mid) ans = Math.max(ans, query(node.right, mid + 1, end, l, r));
        return ans;
    }
    
    private void pushUp(Node node) {
        node.val = Math.max(node.left.val, node.right.val);
    }
    
    private void pushDown(Node node) {
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
    }
    
    class Node {        
        Node left, right;        
        int val;
    }
}

673. 最长递增子序列的个数

1 <= nums.length <= 2000
-106 <= nums[i] <= 106

值域范围不是特别大,可以直接用线段树保存整个值域区间。但因为数组的长度只有 2000,先对数组进行离散化处理。把数组中的元素 按照大小依次映射到 [0, len(nums) - 1] 这个区间。

构建一棵长度为 len(nums) 的线段树,其中每个线段树的结点保存一个二元组, val = [length, count] :以当前结点为结尾的子序列所能达到的最长递增子序列长度 length 和最长递增子序列对应的数量 count。

顺序遍历数组 nums。对于当前元素 nums[i]:
查找 [0, nums[i - 1]] 离散化后对应区间结点的二元组,也就是查找以区间 [0, nums[i - 1]] 上的点为结尾的子序列所能达到的最长递增子序列长度和其对应的数量,即 val = [length, count]。
如果所能达到的最长递增子序列长度为 0,则加入 nums[i] 之后最长递增子序列长度变为 1,且数量也变为 1。
如果所能达到的最长递增子序列长度不为 0,则加入 nums[i] 之后最长递增子序列长度 + 1,但数量不变。
计算的 val 值更新 nums[i] 对应节点的 val 值。然后继续向后遍历,重复进行第 3 ~ 4 步操作。
最后查询以区间 [0, nums[len(nums) - 1]] 上的点为结尾的子序列所能达到的最长递增子序列长度和其对应的数量。返回对应的数量即为答案。

此题最重要的是理解线段树的叶子节点所代表的含义,比如有一个叶子节点 [1,1] 它表示以数字1“结尾”的最长递增子序列的长度为多少,以及该长度对应的序列的个数为多少,这样就符合区间加法的性质了,比如 [1,2] 分为 [1,1] [2,2] 求出以1“结尾”的最长递增子序列的长度与个数,求出以 2 “结尾”的最长递增子序列的长度与个数,两个的长度如果相等且为 L, 个数分别为 n1, n2,则区间 [1, 2] 的长度为L的最长递增子序列的个数为n1 + n2; 如果两个的长度不等,则取较长的那个序列的长度与个数。至此区间 [1, 2] 中最长递增子序列的长度与个数就求出来了。这也是 merge 函数的思想。

class Solution {
    Node[] tree;
    public int findNumberOfLIS(int[] nums) {
        int n = nums.length;
        tree = new Node[4 * n];
        for(int i = 0; i < 4 * n; i ++) tree[i] = new Node();
        build(0, 0, n - 1);
        Map<Integer, Integer> map = new HashMap();
        Integer[] index = new Integer[n];
        Arrays.setAll(index, i -> i);
        Arrays.sort(index, (i, j) -> nums[i] - nums[j]);
        for(int i = 0; i < n; i++) map.put(nums[index[i]], i);
        for(int x : nums){
            int idx = map.get(x);
            int[] val = query(0, 0, idx - 1);
            if(val[0] == 0) val = new int[]{1, 1};
            else val = new int[]{val[0] + 1, val[1]};
            update(0, idx, val);
        }
        return query(0, 0, n - 1)[1];
    }
    
    void build(int node, int start, int end){
        tree[node].start = start;
        tree[node].end = end;
        if (start == end) {
            tree[node].val = new int[]{0, 0};
            return;
        }
        int mid = start + (end - start) / 2;
        int left_node = node * 2 + 1;
        int right_node = node * 2 + 2;            
        build(left_node, start, mid);
        build(right_node, mid + 1, end);
        tree[node].val = merge(tree[left_node].val, tree[right_node].val);  
    }

    void update(int node, int x, int[] val) {
        int start = tree[node].start;
        int end = tree[node].end;
        if (start == end) {
            tree[node].val = merge(tree[node].val, val);
            return;
        }       
        int mid = start + end >> 1;
        int left_node = node * 2 + 1;
        int right_node = node * 2 + 2;

        if (x <= mid) update(left_node, x, val);
        else update(right_node, x, val);
        tree[node].val = merge(tree[left_node].val, tree[right_node].val);   
    } 

    int[] query(int node, int l, int r) {
        int start = tree[node].start;
        int end = tree[node].end;
        if (l <= start && end <= r) return tree[node].val;
        if (end < l || start > r) return new int[]{0, 0}; 
        int mid = start + end >> 1;
        int left_node = node * 2 + 1;
        int right_node = node * 2 + 2;
        int[] a = {0, 0}, b = {0, 0}; // 默认值
        if (l <= mid) a = query(left_node, l, r);
        if (r > mid) b = query(right_node, l, r);
        return merge(a, b);
    }

    int[] merge(int[] x, int[] y){
        if(x[0] == y[0]){            
            return new int[]{x[0], x[1] + y[1]};
        }
        return x[0] > y[0] ? x : y;
    }
}

class Node {        
    int start = -1, end = -1;      
    int[] val = new int[2];
}

动态开点

class Solution {
    public int findNumberOfLIS(int[] nums) {
        for(int x : nums){
            if(x > R) R = x;
            if(x < L) L = x;
        }
        L--; // 为了处理最小的一个
        // 单点更新,区间查询
        for (int x : nums) {
            // 查询区间 [L, x - 1] 的最大值
            int[] val = query(root, L, R, L, x - 1);
            update(root, L, R, x, new int[]{val[0] + 1, val[1]});
        }
        return root.val[1];
    }
    
    int L = (int) 1e6, R = -(int)1e6; // [最小值 - 1, 最大值]
    Node root = new Node();
    void update(Node node, int start, int end, int x, int[] val) {
        if (start == end) {
            node.val = merge(node.val, val);
            return ;
        }
        pushDown(node);
        int mid = start + end >> 1;
        if (x <= mid) update(node.left, start, mid, x, val);
        else update(node.right, mid + 1, end, x, val);
        pushUp(node);
    }
    
    int[] query(Node node, int start, int end, int l, int r) {
        if (l <= start && end <= r) return node.val;
        pushDown(node);
        int mid = start + end >> 1;
        int[] a = {0, 1}, b = {0, 1}; // 默认值
        if (l <= mid) a = query(node.left, start, mid, l, r);
        if (r > mid) b = query(node.right, mid + 1, end, l, r);
        return merge(a, b);
    }
    
    void pushUp(Node node) {
        node.val = merge(node.left.val, node.right.val);
    }
    
    void pushDown(Node node) {
        if (node.left == null) node.left = new Node();
        if (node.right == null) node.right = new Node();
    }
    
    int[] merge(int[] x, int[] y){
        if(x[0] == y[0]){
            if(x[0] == 0) return new int[]{0, 1};
            return new int[]{x[0], x[1] + y[1]};
        }
        return x[0] > y[0] ? x : y;
    }
}

class Node {        
    Node left, right;        
    int[] val = new int[2];
}

2213. 由单个字符重复的最长子字符串

class Solution {
    char[] cs;
    public int[] longestRepeating(String s, String queryCharacters, int[] queryIndices) {
        cs = s.toCharArray();
        char[] qcs = queryCharacters.toCharArray();
        int m = qcs.length;
        int n = cs.length;
        Node root = build(0, n - 1);
        int[] res = new int[m];
        for (int i = 0; i < m; i++) {
            res[i] = query(root, 0, n - 1, queryIndices[i], qcs[i]);
        }
        return res;
    }

    private Node build(int l, int r) {
        Node node = new Node();
        if (l == r) return node;
        int mid = l + r >>> 1;
        node.left = build(l, mid);
        node.right = build(mid + 1, r);
        merge(node, l, r, mid);
        return node;
    }
	// 查询和更新一个函数
    private int query(Node root, int l, int r, int i, char c) {
        if (l == r) cs[i] = c;
        else {
            int mid = l + r >> 1;
            if (i <= mid) query(root.left, l, mid, i, c);
            else query(root.right, mid + 1, r, i, c);
            merge(root, l, r, mid);
        }
        return root.max;
    }

    void merge(Node root, int l, int r, int mid) {
        root.max = Math.max(root.left.max, root.right.max);
        root.pre = root.left.pre;
        root.suf = root.right.suf;
        if (cs[mid] == cs[mid + 1]) {
            root.max = Math.max(root.max, root.left.suf + root.right.pre);
            if (root.left.max == mid - l + 1) root.pre += root.right.pre;            
            if (root.right.max == r - mid) root.suf += root.left.suf;            
        }
    }

    static class Node {
        // 前缀,后缀,最大
        int pre = 1, suf = 1, max = 1;
        // 区间信息
        //int l, r; // 通过传参获得
        Node left = null, right = null;
        /*
        public Node(int l, int r) {
            this.l = l;
            this.r = r;
        }
        */
    }
}

1157. 子数组中占绝大多数的元素

class MajorityChecker {
    int[] arr;
    int n;
    List<Integer>[] rec = new List[20005];    
    Node root;
    public MajorityChecker(int[] arr) { 
        Arrays.setAll(rec, v -> new ArrayList());
        n = arr.length;                
        this.arr = arr;
        for (int i = 0; i < n; i++) rec[arr[i]].add(i);
        root = build(0, n - 1);
        for (int i = 1; i <= 20000; i++) rec[i].add(n + 1);
    }

    Node build(int l, int r) {
        Node node = new Node();
        if (l == r) {            
            node.val = new int[]{arr[r], 1};
            return node;
        }        
        int mid = l + r >> 1; 
        node.left = build(l, mid);
        node.right = build(mid + 1, r);       
        node.val = merge(node.left.val, node.right.val);  
        return node;          
    }
    
    int[] merge(int[] x, int[] y){
        if (x[0] == y[0]) return new int[]{x[0], x[1] + y[1]};           
        if (x[1] >= y[1]) return new int[]{x[0], x[1] - y[1]}; 
        return new int[]{y[0], y[1] - x[1]}; 
    }    
   
    int[] query(Node node, int l, int r, int x, int y) { 
        if (x <= l && r <= y) return node.val;
        int mid = l + r >> 1;
        if (y <= mid) return query(node.left, l, mid, x, y);
        if (x > mid) return query(node.right, mid + 1, r, x, y);
        return merge(query(node.left, l, mid, x, y), query(node.right, mid + 1, r, x, y));
    }
        
    public int query(int left, int right, int threshold) {
        int ask = query(root, 0, n - 1, left, right)[0];  
        if (search(rec[ask], right + 1) - search(rec[ask], left) < threshold) ask = -1;
        return ask;
    }
        
    int search(List<Integer> rec, int x) {
        int l = 0, r = rec.size() - 1;
        while (l < r) {
            int mid = l + r >> 1;
            if (rec.get(mid) < x) l = mid + 1;
            else r = mid;                
        }
         return r;
    }
}

class Node {
    Node left, right;
    int[] val = new int[2];
}

2286. 以组为单位订音乐会的门票

218. 天际线问题

315. 计算右侧小于当前元素的个数

327. 区间和的个数

406. 根据身高重建队列

493. 翻转对

850. 矩形面积 II

1505. 最多 K 次交换相邻数位后得到的最小整数

1521. 找到最接近目标值的函数值

1622. 奇妙序列

1649. 通过指令创建有序数组

1687. 从仓库到码头运输箱子

2080. 区间内查询数字的频率

2179. 统计数组中好三元组数目

2276. 统计区间中的整数数目

2424. 最长上传前缀

2426. 满足不等式的数对数目

LCP 05. 发 LeetCoin

LCP 09. 最小跳跃次数

LCP 27. 黑盒光线反射

LCP 52. 二叉搜索树染色

「值域爆炸,查询有限」区间问题的几种方式

729. 我的日程安排表 I

模拟

利用 book 操作最多调用 1000 次,可以使用一个数组存储所有已被预定的日期 [start, end - 1],对于每次 book 操作,检查当前传入的 [start, end] 是否会与已有的日期冲突,冲突返回 False,否则将 [start, end - 1] 插入数组并返回 True。

class MyCalendar {
    List<int[]> list = new ArrayList<>();
    public boolean book(int start, int end) {
        end--;
        for (int[] info : list) {
            int l = info[0], r = info[1];
            if (start > r || end < l) continue;
            return false;
        }
        list.add(new int[]{start, end});
        return true;
    }
}

时间复杂度:令 为 book 的最大调用次数,复杂度为 O(n2)
空间复杂度:O(n)

有序集合(红黑树)

解法一,每次的 book 操作都不可避免的需要遍历所有已存在的日期。

如果使用 TreeMap(底层为红黑树)来维护所有日期,可以避免对所有已存在的日期进行遍历。

class MyCalendar {
    TreeMap<Integer, Integer> tm = new TreeMap();    
    public boolean book(int start, int end) {
        Integer prev = tm.floorKey(start), next = tm.ceilingKey(start);
        if ((prev == null || tm.get(prev) <= start) && (next == null || end <= next)) {
            tm.put(start, end);
            return true;
        }
        return false;
    }
}

时间复杂度:令 为 book 的最大调用次数,复杂度为 O(n * log n)
空间复杂度:O(n)

线段树(动态开点)

对于常规的线段树实现来说,都是一开始就调用 build 操作创建空树,而线段树一般以「满二叉树」的形式用数组存储,因此需要 4 * n 的空间,并且这些空间在起始 build 空树的时候已经锁死。

如果一道题仅仅是 「值域很大」的离线题(提前知晓所有的询问),还能通过 「离散化」 来进行处理,将值域映射到一个小空间去,从而解决 MLE 问题。

但对于本题而言,由于「强制在线」的原因,无法进行「离散化」,同时值域大小达到 1e9 级别,因此如果想要使用「线段树」进行求解,只能采取「动态开点」的方式进行。

class MyCalendar {
    class Node {
        // ls 和 rs 分别代表当前节点的左右子节点在 tr 的下标
        // val 代表当前节点有多少数
        // add 为懒标记
        int ls, rs, add, val;
    }
    int N = (int)1e9, M = 120010, cnt = 1;
    Node[] tr = new Node[M];
    void update(int u, int lc, int rc, int l, int r, int v) {
        if (l <= lc && rc <= r) {
            tr[u].val += (rc - lc + 1) * v;
            tr[u].add += v;
            return ;
        }
        lazyCreate(u);
        pushdown(u, rc - lc + 1);
        int mid = lc + rc >> 1;
        if (l <= mid) update(tr[u].ls, lc, mid, l, r, v);
        if (r > mid) update(tr[u].rs, mid + 1, rc, l, r, v);
        pushup(u);
    }
    int query(int u, int lc, int rc, int l, int r) {
        if (l <= lc && rc <= r) return tr[u].val;
        lazyCreate(u);
        pushdown(u, rc - lc + 1);
        int mid = lc + rc >> 1, ans = 0;
        if (l <= mid) ans = query(tr[u].ls, lc, mid, l, r);
        if (r > mid) ans += query(tr[u].rs, mid + 1, rc, l, r);
        return ans;
    }
    void lazyCreate(int u) {
        if (tr[u] == null) tr[u] = new Node();
        if (tr[u].ls == 0) {
            tr[u].ls = ++cnt;
            tr[tr[u].ls] = new Node();
        }
        if (tr[u].rs == 0) {
            tr[u].rs = ++cnt;
            tr[tr[u].rs] = new Node();
        }
    }
    void pushdown(int u, int len) {
        tr[tr[u].ls].add += tr[u].add; tr[tr[u].rs].add += tr[u].add;
        tr[tr[u].ls].val += (len - len / 2) * tr[u].add; tr[tr[u].rs].val += len / 2 * tr[u].add;
        tr[u].add = 0;
    }
    void pushup(int u) {
        tr[u].val = tr[tr[u].ls].val + tr[tr[u].rs].val;
    }
    public boolean book(int start, int end) {
        if (query(1, 1, N + 1, start + 1, end) > 0) return false;
        update(1, 1, N + 1, start + 1, end, 1);
        return true;
    }
}

仓库:https://github.com/SharingSource/LogicStack-LeetCode 。


举报

相关推荐

0 条评论