Link
一道我觉得非常好的题目，有些地方还没想清楚，需要再回顾一下。
// Array bounds: vertices (maxn) and directed arcs (maxm, two per undirected edge).
constexpr int maxn = 300000 + 10;
constexpr int maxm = 600000 + 10;
// General-purpose template constants.
constexpr int P = 1000000000 + 7; // alternative modulus: 998244353
constexpr int INF = 0x3f3f3f3f;
constexpr double eps = 1e-7;

int n, m;          // vertex count and query count, read in solve()
int Log2[maxn];    // Log2[i] = floor(log2(i)), filled by init()
int fa[maxn][30];  // fa[v][k] = 2^k-th ancestor of v (binary lifting table)
int dep[maxn];     // depth of each vertex; dep[0] is a sentinel set by solve()
bool vis[maxn];    // per-vertex visited flags
int head[maxn];    // first arc of each vertex's adjacency list (0 = empty list)
int p;             // number of arcs stored in edge[] so far
// One directed arc in the forward-star (linked adjacency list) graph storage.
struct Edge {
    int to;       // vertex the arc points to
    int dis = 1;  // arc weight; add_edge() always overwrites it (default 1)
    int next;     // index of the next arc leaving the same vertex, 0 = none
} edge[maxm];     // arcs occupy indices 1..p; index 0 acts as the list terminator
// Root the tree at `cur` (vertex 1 by default) and fill the binary-lifting table:
//   dep[v]    = depth of v (solve() sets dep[0] = -1, so the root gets depth 0),
//   fa[v][0]  = parent of v (0 for the root),
//   fa[v][k]  = 2^k-th ancestor of v, for every k with 2^k <= dep[v].
// The parent is skipped explicitly, the same way dfs2() does it, instead of
// through a visited flag. init() clears dep/head but not any visited flags, so
// a flag-based guard would make a second call return at the root and leave
// dep/fa stale. With the parent check, the traversal carries no extra state.
// Assumes the adjacency lists describe a tree (no cycles).
// Recursion depth equals the tree height (up to n - 1 on a path-shaped tree).
void dfs(int cur = 1, int fath = 0) {
    dep[cur] = dep[fath] + 1;
    fa[cur][0] = fath;
    // fa[cur][i-1] lies 2^(i-1) levels up, so its own (i-1)-th entry is already filled.
    for (int i = 1; i <= Log2[dep[cur]]; i++)
        fa[cur][i] = fa[fa[cur][i - 1]][i - 1];
    for (int i = head[cur]; i; i = edge[i].next)
        if (edge[i].to != fath)
            dfs(edge[i].to, cur);
}
// Lowest common ancestor of a and b, using the fa[][] table that dfs() fills.
// Step 1 lifts the deeper vertex to the other's depth, taking the largest
// power-of-two jump that fits the remaining gap each time.
// Step 2 lifts both vertices in lockstep while their 2^k-th ancestors differ.
// The parent of the final pair is the LCA.
int LCA(int a, int b) {
    if (dep[a] > dep[b])
        swap(a, b);  // ensure b is the deeper one
    for (int gap = dep[b] - dep[a]; gap > 0; gap = dep[b] - dep[a])
        b = fa[b][Log2[gap]];
    if (a == b)
        return a;  // a was an ancestor of b
    for (int k = Log2[dep[a]]; k >= 0; --k) {  // longest jumps first
        if (fa[a][k] == fa[b][k])
            continue;  // would reach or pass the LCA; try a shorter jump
        a = fa[a][k];
        b = fa[b][k];
    }
    return fa[a][0];
}
// Prepare for a new tree of n vertices:
//   - restart the arc counter p;
//   - clear dep[1..n] and head[1..n];
//   - fill Log2[i] = floor(log2(i)) for 2 <= i <= n.
// Log2[1] is not written; it relies on the global array's zero initialization.
// Note: vis[] is not cleared here.
void init() {
    p = 0;
    for (int v = n; v >= 1; --v)
        dep[v] = head[v] = 0;
    for (int i = 2; i <= n; ++i)
        Log2[i] = 1 + Log2[i >> 1];
}
void add_edge(int u, int v, int w = 1) {
p++;
edge[p].to = v;
edge[p].dis = w;
edge[p].next = head[u];
head[u] = p;
}
int w[maxn];              // per-vertex value read in solve() and matched in dfs2()
vector< pii > op1[maxn];  // bucket updates keyed by dep[s] (s -> lca half of a path)
vector< pii > op2[maxn];  // bucket updates keyed by dep[s] - 2*dep[lca] + n (lca -> t half)
int d1[maxn << 1];        // bucket counters for op1 keys
int d2[maxn << 1];        // bucket counters for op2 keys (offset by n to stay non-negative)
int ans[maxn];            // per-vertex result computed by dfs2() and printed by solve()
// Register the path s -> t as tree-difference operations, which dfs2() consumes.
// With lca = LCA(s, t), a vertex u on the s -> lca half satisfies
// w[u] + dep[u] == dep[s] exactly when w[u] equals dist(s, u) = dep[s] - dep[u].
// This half is +1 at s and -1 at lca's parent, keyed by dep[s] in op1.
// On the lca -> t half (lca excluded), dist(s, u) = dep[s] + dep[u] - 2*dep[lca],
// so the match is w[u] - dep[u] == dep[s] - 2*dep[lca].
// This half is +1 at t and -1 at lca, keyed (offset by n) in op2.
void update(int s, int t) {
    const int lca = LCA(s, t);
    const int up_key = dep[s];
    const int down_key = dep[s] - 2 * dep[lca] + n;
    op1[s].emplace_back(up_key, 1);
    op1[fa[lca][0]].emplace_back(up_key, -1);
    op2[t].emplace_back(down_key, 1);
    op2[lca].emplace_back(down_key, -1);
}
void dfs2(int u, int last) {
int v1 = w[u] + dep[u];
int v2 = w[u] - dep[u] + n;
int res1 = d1[v1], res2 = d2[v2];
for(int i = head[u]; i; i = edge[i].next) {
int v = edge[i].to;
if(v == last) continue;
dfs2(v, u);
}
for(int i = 0; i < op1[u].size(); i++)
d1[op1[u][i].first] += op1[u][i].second;
for(int i = 0; i < op2[u].size(); i++)
d2[op2[u][i].first] += op2[u][i].second;
ans[u] = (d1[v1] - res1) + (d2[v2] - res2);
}
// Read one instance, compute and print the answers:
//   n m; n-1 undirected edges; w[1..n]; m queries "s t".
// Prints ans[1..n] separated by spaces, then ends the line.
void solve() {
    cin >> n >> m;
    init();
    dep[0] = -1;  // sentinel: the root (vertex 1) then gets depth 0 in dfs()
    for (int k = 1; k < n; ++k) {
        int a, b;
        cin >> a >> b;
        add_edge(a, b);
        add_edge(b, a);
    }
    dfs();
    for (int v = 1; v <= n; ++v)
        cin >> w[v];
    for (int q = 0; q < m; ++q) {
        int from, to;
        cin >> from >> to;
        update(from, to);
    }
    dfs2(1, 0);
    for (int v = 1; v <= n; ++v)
        cout << ans[v] << ' ';
    cout << endl;
}