Lowest Common Ancestor : Farach-Colton & Bender Algorithm
假设 G 是一棵树。对于每一个形式为 ( u , v ) (u,v) (u,v) 的查询,我们想找到节点 u u u 和 v v v$ 的最近共同祖先,也就是说,我们想找到一个节点 w w w ,它位于从 u u u 和 v v v 到根节点的路径上,如果有多个节点,我们选择离根节点最远的那个。换句话说,所需的节点 w w w 是 u u u 和 v v v 的最近祖先。特别情况下,如果 u u u 是 v v v 的祖先,那么 u u u 就是它们的最近共同祖先。
本文将描述的算法是由 Farach-Colton 和 Bender 共同开发的,是一种渐进式的最优算法。
Algorithm
我们使用经典的 LCA 问题归约为 RMQ 问题。用 DFS 遍历树的所有节点,并保留一个包含所有访问过的节点和这些节点的高度的数组。两个节点 u u u 和 v v v 的 LCA 是在遍历中出现的 u u u 和 v v v 之间的节点,它的高度最小。
在下图(表)中,你可以看到一种情况下的欧拉遍历和访问过的节点以及它们的高度。
在 Lowest Common Ancestor (可能需要科学上网)一文中读到更多关于这种归约的信息。在上述这篇文章中,一个范围的最小值是通过平方分解求得的
O
(
N
)
O\left( \sqrt{N}\right)
O(N) 或者用线段树求得的
O
(
log
N
)
O\left( \log N\right)
O(logN) 。
而在本文中,我们将探讨如何能在 O ( 1 ) O\left( 1\right) O(1) 的时间内解决给定的范围最小值查询,同时仍然只用 O ( N ) O\left( N\right) O(N) 的时间复杂度进行预处理。
需要注意的是,归约的 RMQ 问题是非常具体的:数组中任何两个相邻的元素正好相差 1 1 1(因为数组的元素只不过是按遍历顺序访问的节点的高度,我们要么去找一个后代,在这种情况下,下一个元素大 1 1 1 ;要么回到祖代,在这种情况下,下一个元素小 1 1 1 )。Farach-Colton 和 Bender 算法描述的正是这种专门的 RMQ 问题的解决方案。
让我们用 A A A 来表示我们要对其进行范围最小查询的数组。而 N N N 将是 A A A 的大小。
有一个简单的数据结构,需要进行 O ( N l o g N ) O\left( N log N\right) O(NlogN) 的预处理,我们可以用它来解决 RMQ 问题,其中每个查询为 O ( 1 ) O(1) O(1) 的 稀疏表 。我们创建一个表 T T T ,其中每个元素 T [ i ] [ j ] T[i][j] T[i][j] 等于 A A A 在区间 [ i , i + 2 j − 1 ] [i,i+2^{j}-1] [i,i+2j−1] 中的最小值。显然 0 ≤ j ≤ ⌈ l o g N ⌉ 0\leq j\leq\lceil log N\rceil 0≤j≤⌈logN⌉ ,因此稀疏表的大小将是 O ( N l o g N ) O\left( Nlog N\right) O(NlogN) ,你可以在 O ( N l o g N ) O(Nlog N) O(NlogN) 内轻松建立表,其中 T [ i ] [ j ] = m i n ( T [ i ] [ j − 1 ] , T [ i + 2 j − 1 ] [ j − 1 ] ) T[i][j]=min(T[i][j-1],T[i+2^{j-1}][j-1]) T[i][j]=min(T[i][j−1],T[i+2j−1][j−1]) 。
我们如何用这种数据结构在 O ( 1 ) O(1) O(1) 内回答一条查询 RMQ ?假设收到的查询是 [ l , r ] [l,r] [l,r] ,那么答案是 m i n ( T [ l ] [ s z ] , T [ r − 2 s z + 1 ] [ s z ] ) min(T[l][sz],T[r-2^{sz}+1][sz]) min(T[l][sz],T[r−2sz+1][sz]) ,其中 s z sz sz 是最大的指数,使 2 s z 2^{sz} 2sz 不大于范围长度 r − l + 1 r-l+1 r−l+1 。事实上,我们可以把范围 [ l , r ] [l,r] [l,r] 覆盖在两个长度为 2 s z 2^{sz} 2sz 的片段上,一端从 l l l 开始,另一端在 r r r 结束。这些片段是重叠的,但这并不影响我们的计算。为了真正实现每个查询的时间复杂度为 O ( 1 ) O(1) O(1) ,我们需要知道从 1 1 1 到 N N N 的所有可能的长度的 s z sz sz 值,这是很容易预先计算的。
现在我们想将预处理的复杂度降低到 O ( N ) O\left( N\right) O(N) 。
我们把数组 A A A 分成大小为 K = l o g 2 N 2 K=\dfrac{log_{2} N}{2} K=2log2N 的块。对于每个块,我们计算出最小的元素,并将其存储在数组 B B B 中, B B B 的大小为 N K \dfrac{N}{K} KN 。我们从数组 B B B 构建一个稀疏表。它的规模和时间的复杂度为: N K l o g ( N K ) = 2 N l o g N l o g ( 2 N l o g N ) = 2 N l o g N ( 1 + l o g ( N l o g N ) ) ≤ 2 N l o g N + 2 N = O ( N ) \dfrac{N}{K}log \left( \dfrac{N}{K}\right)=\dfrac{2N}{log N}log \left( \dfrac{2N}{log N}\right)=\dfrac{2N}{log N}\left( 1+log\left( \dfrac{N}{log N}\right)\right)\leq \dfrac{2N}{log N}+2N=O\left( N\right) KNlog(KN)=logN2Nlog(logN2N)=logN2N(1+log(logNN))≤logN2N+2N=O(N)
现在我们只需要学习如何快速解决每个区块内的最小范围查询。事实上,如果收到的范围最小查询是 [ l , r ] [l,r] [l,r] ,并且 l l l 和 r r r 在不同的区块中,那么答案就是以下三个值的最小值:从 l l l 开始的 l l l 区块的后缀的最小值,从 r r r 结束的 r r r 区块的前缀的最小值,以及这些区块之间的最小值。这之间的区块的最小值可以用稀疏表在 O ( 1 ) O\left( 1\right) O(1) 内回答。因此,我们只剩下区块内的最小值查询。
这里我们将利用数组的属性。请记住,数组中只是树中的高度值将总是相差 1 1 1 。如果我们去掉一个区块的第一个元素,并将其与区块中的其他每一项相减,那么每个区块都可以由一个长度为 K − 1 K-1 K−1 的序列来识别,该序列由数字 1 1 1 和 0 0 0 组成,由于这些区块非常小,只有几个不同的序列可以出现。可能的序列数量为: 2 K − 1 = 2 log N 2 − 1 = 2 l o g N 2 = N 2 2^{K-1}=2^{\dfrac{\log N}{2}-1}=\dfrac{\sqrt{2^{logN}}}{2}=\dfrac{\sqrt{N}}{2} 2K−1=22logN−1=22logN=2N
因此,不同区块的数量是 O ( N ) O\left( \sqrt{N}\right) O(N) ,因此我们可以在 O ( N K 2 ) = O ( N log 2 N ) = O ( N ) O\left( \sqrt{N}K^{2}\right)= O\left( \sqrt{N}\log ^{2}N\right)=O\left( N\right) O(NK2)=O(Nlog2N)=O(N) 时间内预先计算出所有不同区块内的范围最小值查询结果。为了实现这一点,我们可以用一个长度为 K − 1 K-1 K−1 的比特掩码来描述一个区块的特征(适合于标准int),并将最小值的索引存储在一个大小为 O ( N log 2 N ) O\left( \sqrt{N}\log ^{2}N\right) O(Nlog2N) 的数组 b l o c k [ m a s k ] [ l ] [ r ] block[mask][l][r] block[mask][l][r] 中。
因此,我们学会了如何在每个区块内预先计算范围内的最小查询,以及在一个区块范围内的范围内的最小查询,所有这些都在 O ( N ) O\left( N\right) O(N) 的时间中。通过这些预计算,我们最多使用四个预先计算的值:包含 l l l 的区块的最小值,包含 r r r 的区块的最小值,以及它们之间区块重叠段的两个最小值,就可以在 O ( 1 ) O\left( 1\right) O(1) 内解决每个查询。
Implementation
int n;
vector<vector<int>> adj;
int block_size, block_cnt;
vector<int> first_visit;
vector<int> euler_tour;
vector<int> height;
vector<int> log_2;
vector<vector<int>> st;
vector<vector<vector<int>>> blocks;
vector<int> block_mask;
void dfs(int v, int p, int h) {
first_visit[v] = euler_tour.size();
euler_tour.push_back(v);
height[v] = h;
for (int u : adj[v]) {
if (u == p)
continue;
dfs(u, v, h + 1);
euler_tour.push_back(v);
}
}
int min_by_h(int i, int j) {
return height[euler_tour[i]] < height[euler_tour[j]] ? i : j;
}
void precompute_lca(int root) {
// get euler tour & indices of first occurences
first_visit.assign(n, -1);
height.assign(n, 0);
euler_tour.reserve(2 * n);
dfs(root, -1, 0);
// precompute all log values
int m = euler_tour.size();
log_2.reserve(m + 1);
log_2.push_back(-1);
for (int i = 1; i <= m; i++)
log_2.push_back(log_2[i / 2] + 1);
block_size = max(1, log_2[m] / 2);
block_cnt = (m + block_size - 1) / block_size;
// precompute minimum of each block and build sparse table
st.assign(block_cnt, vector<int>(log_2[block_cnt] + 1));
for (int i = 0, j = 0, b = 0; i < m; i++, j++) {
if (j == block_size)
j = 0, b++;
if (j == 0 || min_by_h(i, st[b][0]) == i)
st[b][0] = i;
}
for (int l = 1; l <= log_2[block_cnt]; l++) {
for (int i = 0; i < block_cnt; i++) {
int ni = i + (1 << (l - 1));
if (ni >= block_cnt)
st[i][l] = st[i][l-1];
else
st[i][l] = min_by_h(st[i][l-1], st[ni][l-1]);
}
}
// precompute mask for each block
block_mask.assign(block_cnt, 0);
for (int i = 0, j = 0, b = 0; i < m; i++, j++) {
if (j == block_size)
j = 0, b++;
if (j > 0 && (i >= m || min_by_h(i - 1, i) == i - 1))
block_mask[b] += 1 << (j - 1);
}
// precompute RMQ for each unique block
int possibilities = 1 << (block_size - 1);
blocks.resize(possibilities);
for (int b = 0; b < block_cnt; b++) {
int mask = block_mask[b];
if (!blocks[mask].empty())
continue;
blocks[mask].assign(block_size, vector<int>(block_size));
for (int l = 0; l < block_size; l++) {
blocks[mask][l][l] = l;
for (int r = l + 1; r < block_size; r++) {
blocks[mask][l][r] = blocks[mask][l][r - 1];
if (b * block_size + r < m)
blocks[mask][l][r] = min_by_h(b * block_size + blocks[mask][l][r],
b * block_size + r) - b * block_size;
}
}
}
}
int lca_in_block(int b, int l, int r) {
return blocks[block_mask[b]][l][r] + b * block_size;
}
int lca(int v, int u) {
int l = first_visit[v];
int r = first_visit[u];
if (l > r)
swap(l, r);
int bl = l / block_size;
int br = r / block_size;
if (bl == br)
return euler_tour[lca_in_block(bl, l % block_size, r % block_size)];
int ans1 = lca_in_block(bl, l % block_size, block_size - 1);
int ans2 = lca_in_block(br, 0, r % block_size);
int ans = min_by_h(ans1, ans2);
if (bl + 1 < br) {
int l = log_2[br - bl - 1];
int ans3 = st[bl+1][l];
int ans4 = st[br - (1 << l)][l];
ans = min_by_h(ans, min_by_h(ans3, ans4));
}
return euler_tour[ans];
}