题目
题面
简要题意:
给定一棵
n
n
n 个节点的树,初始时
1
1
1 号节点为红色,其余为蓝色。
要求支持如下操作:
1. 将一个节点变为红色。
2. 询问节点
u
u
u 到最近红色节点的距离。
共
q
q
q 次操作。
1
≤
n
,
q
≤
1
0
5
1 \leq n, q \leq 10^5
1≤n,q≤105。
分析:
非常 典 的一道题。
我们首先考虑一种 修改
O
(
n
)
O(n)
O(n),查询
O
(
1
)
O(1)
O(1) 的算法:每次改变一个点的颜色就把它放进队列里跑一遍 bfs,去更新其它点到红点的最小值。
接着我们考虑一种 修改 O ( 1 ) O(1) O(1),查询 O ( n ) O(n) O(n) 的算法:每次 O ( 1 ) O(1) O(1) 标记一个点是否为红色。然后每次查询枚举红色的点并计算距离,时间复杂度是 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n) 的。
我们考虑如何平衡这两种算法。
因为 bfs 可以在
O
(
n
)
O(n)
O(n) 的复杂度内跑 多个终点的最短路,因此我们可以将红点储存起来一起跑bfs。所以可以对操作进行分块。
设块长为
S
S
S,我们每一次从当前块到下一块时,我们把当前块的 所有染红操作的点 放进队列里面跑 bfs 更新 其它点的
d
i
s
dis
dis 值。然后对于当前块的询问,我们扫块内的所有操作,如果为
1
1
1 操作,那么我们
O
(
l
o
g
2
n
)
O(log_2n)
O(log2n) 的复杂度内查出询问点和修改点的距离并与
d
i
s
dis
dis 数组取
m
i
n
min
min 即可。
时间复杂度是
O
(
q
S
×
n
+
q
×
S
×
l
o
g
2
n
)
O(\frac{q}{S} \times n + q \times S \times log_2n)
O(Sq×n+q×S×log2n) 的,当
S
=
n
l
o
g
2
n
S = \sqrt{\frac{n}{log_2n}}
S=log2nn 时复杂度最小,为
O
(
q
n
l
o
g
2
n
)
O(q\sqrt{nlog_2n})
O(qnlog2n)。
CODE:
#include<bits/stdc++.h>// 好题
#define pb push_back
using namespace std;
const int N = 1e5 + 10;
typedef pair< int, int > PII;
int n, Q, dis[N], blo, op, x, bl, u, v, dep[N], fa[N][25];
inline int read(){
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){if(c == '-') f = -1; c = getchar();}
while(isdigit(c)){x = (x << 1) + (x << 3) + (c ^ 48); c = getchar();}
return x * f;
}
vector< int > E[N];
vector< PII > vec[N];
queue< int > q;
void dfs(int x, int fat){
dep[x] = dep[fat] + 1; fa[x][0] = fat;
for(int i = 1; i <= 20; i++) fa[x][i] = fa[fa[x][i - 1]][i - 1];
for(auto v : E[x]){
if(v == fat) continue;
dfs(v, x);
}
}
int LCA(int x, int y){
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; i--){
if(dep[fa[x][i]] >= dep[y]) x = fa[x][i];
}
if(x == y) return x;
for(int i = 20; i >= 0; i--){
if(fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
}
return fa[x][0];
}
void bfs(){
while(!q.empty()){
int u = q.front(); q.pop();
for(auto v : E[u]){
if(dis[v] > dis[u] + 1){
dis[v] = dis[u] + 1;
q.push(v);
}
}
}
}
int main(){
memset(dis, 0x3f, sizeof dis);
n = read(), Q = read();
for(int i = 1; i < n; i++){
u = read(), v = read();
E[u].pb(v); E[v].pb(u);
}
dfs(1, 0);
blo = max(1, (int)sqrt(1.0 * n / log2(n)));
for(int i = 1; i <= Q; i++){
op = read(), x = read();
bl = (i - 1) / blo + 1;
vec[bl].pb(make_pair(op, x));
}
dis[1] = 0;
q.push(1);
bfs();
for(int i = 1; i <= bl; i++){
for(int j = 0; j < vec[i].size(); j++){
int op = vec[i][j].first, x = vec[i][j].second;
if(op == 1) dis[x] = 0, q.push(x);
else{
int y = dis[x];
for(int k = 0; k < j; k++){
if(vec[i][k].first == 1) y = min(y, dep[x] + dep[vec[i][k].second] - 2 * dep[LCA(vec[i][k].second, x)]);
}
printf("%d\n", y);
}
}
bfs();
}
return 0;
}