文章目录
传送门
题意:
给你一颗 n n n个点的树,每条边为黑色或者白色,问满足以下条件的路径条数:路径上存在一个不是端点的点,使得两端点到该点的两条路径上两种颜色的边数相等。
1 ≤ n ≤ 100000 1\le n\le 100000 1≤n≤100000
思路:
统计树上路径问题显然需要用到点分治了,这个题维护的信息比较麻烦,想明白了思路还需要考虑如何做才能使代码变得简便好写。
考虑点分治每个步骤需要算的贡献,在选出重心之后,之后的子树中就不会包含重心这个点,所以贡献需要算每颗包含重心的子树内的路径以及这些子树之间构成的路径的贡献,一定不要忘记算包含重心的子树贡献,我由于忘记了调了一年。。
让后就是遍历每颗子树了,将 0 0 0看成 − 1 -1 −1,我们记三个变量,从重心到当前点的前缀和,这样就可以判断这个点到重心的路径上是否有一段和为0了,让后再用数组记一下有多少从重心开始的路径和为 x x x,即 m p [ x ] [ 0 ] mp[x][0] mp[x][0],让后如果第一个变量大于 0 0 0的话,那么也让 m p [ x ] [ 1 ] + + mp[x][1]++ mp[x][1]++。这个时候统计贡献就比较容易了,当有一段和为 0 0 0的时候,此时假设前缀和是 x x x,那么加上 m p [ x ] [ 0 ] mp[x][0] mp[x][0]即可,否则加上 m p [ x ] [ 1 ] mp[x][1] mp[x][1]。让后一颗包含重心的子树内自己贡献的话就是有两段和为 0 0 0,并且当前这个点的前缀和为 0 0 0,让答案加 1 1 1即可。
#include<bits/stdc++.h>
#define X first
#define Y second
#define L (u<<1)
#define R (u<<1|1)
#define Mid (tr[u].l+tr[u].r>>1)
#define pb push_back
using namespace std;
const int N=201000,INF=0x3f3f3f3f,mod=1e9+7,P=100000;
typedef long long LL;
typedef pair<int,int> PII;
int n,k;
vector<PII>v[N];
vector<int>q,qt;
bool st[N];
int mp[N][2],dis[N];
vector<int>all;
int get_size(int u,int f) {
if(st[u]) return 0;
int ans=1;
for(auto x:v[u]) if(x.X!=f) {
ans+=get_size(x.X,u);
}
return ans;
}
int get_wc(int u,int f,int tot,int &wc) {
if(st[u]) return 0;
int sum=1,mx=0;
for(auto x:v[u]) {
if(x.X==f) continue;
int now=get_wc(x.X,u,tot,wc);
mx=max(mx,now); sum+=now;
}
mx=max(mx,tot-sum);
if(mx<=tot/2) wc=u;
return sum;
}
LL get_dis_ans(int u,int f,int w) {
if(st[u]) return 0;
LL ans=0;
if(dis[w+P]) {
ans+=mp[P-w][0];
if(w==0&&dis[P]>=2) ans++;
} else ans+=mp[P-w][1];
dis[w+P]++;
for(auto x:v[u]) {
if(x.X==f) continue;
ans+=get_dis_ans(x.X,u,w+x.Y);
}
dis[w+P]--;
return ans;
}
void get_dis(int u,int f,int w) {
if(st[u]) return;
mp[w+P][0]++;
if(dis[w+P]>0) {
mp[w+P][1]++;
}
all.pb(w+P);
dis[w+P]++;
for(auto x:v[u]) {
if(x.X==f) continue;
get_dis(x.X,u,w+x.Y);
}
dis[w+P]--;
}
LL calc(int u) {
if(st[u]) return 0;
get_wc(u,-1,get_size(u,-1),u);
LL ans=0; st[u]=true;
for(auto x:v[u]) {
ans+=get_dis_ans(x.X,-1,x.Y);
get_dis(x.X,-1,x.Y);
}
for(auto x:all) mp[x][0]=mp[x][1]=0;
all.clear();
for(auto x:v[u]) ans+=calc(x.X);
return ans;
}
void solve() {
scanf("%d",&n);
dis[P]=1;
for(int i=1;i<=n-1;i++) {
int a,b,c; scanf("%d%d%d",&a,&b,&c);
if(c==0) c=-1;
v[a].pb({b,c});
v[b].pb({a,c});
}
printf("%lld\n",calc(1));
}
int main() {
int _=1;
while(_--) {
solve();
}
return 0;
}
/*
5
1 2 1
1 3 0
3 4 0
4 5 1
*/