题目大意:一个$n$个点的树,树上有$m$个点对$(a,b)$,找到一个点$x$,使得$max(dis(x,a_i)+dis(x,b_i))$最小
如果做过幻想乡的战略游戏这道题,应该这道题的思路一眼能看出来
首先如果从一个点向能使答案变小的子树上走,那么从子树上一定不会再回到这个点
所以考虑一个暴力,即每次计算所有子树的答案,然后向最优的方向走
这显然是正确的,但是不够优秀
我们再深入分析一下这道题,可以发现,当且仅当所有的距离等于最大值的点对都在它的一个子树内时才可能使得答案变优
很好理解,因为如果不在通一个子树内,不论向任何地方走,总会有点对的最大值变得更大
然后这样我们就可以用点分治的$getroot$来优化这个过程,复杂度为$nlogn$
代码:
1 #include<iostream>
2 #include<cstdio>
3 #include<cstring>
4 #include<cstdlib>
5 #define M 100010
6 using namespace std;
7 int n,m,num,rt,S,ans=1e9;
8 int head[M],size[M],maxn[M],bel[M],dis[M],u[M],v[M],st[M];
9 bool vis[M];
10 struct point{int to,next,dis;}e[M<<1];
11 void add(int from,int to,int dis)
12 {
13 e[++num].next=head[from];
14 e[num].to=to;
15 e[num].dis=dis;
16 head[from]=num;
17 }
18 void getroot(int x,int fa)
19 {
20 size[x]=maxn[x]=1;
21 for(int i=head[x];i;i=e[i].next)
22 {
23 int to=e[i].to;
24 if(to==fa||vis[to]) continue;
25 getroot(to,x),size[x]+=size[to];
26 maxn[x]=max(maxn[x],size[to]);
27 }
28 maxn[x]=max(maxn[x],S-size[x]);
29 if(maxn[x]<maxn[rt]) rt=x;
30 }
31
32 void dfs(int x,int fa,int id)
33 {
34 bel[x]=id;
35 for(int i=head[x];i;i=e[i].next)
36 if(e[i].to!=fa)
37 {
38 dis[e[i].to]=dis[x]+e[i].dis;
39 dfs(e[i].to,x,id);
40 }
41 }
42
43 void solve(int x)
44 {
45 if(vis[x]) {printf("%d\n",ans);exit(0);}
46 vis[x]=true,dis[x]=0;
47 for(int i=head[x];i;i=e[i].next)
48 {
49 dis[e[i].to]=e[i].dis;
50 dfs(e[i].to,x,e[i].to);
51 }
52 int MX=0,top=0,pos=0;
53 for(int i=1;i<=m;i++)
54 {
55 if(dis[u[i]]+dis[v[i]]>MX)
56 {
57 MX=dis[u[i]]+dis[v[i]];
58 st[top=1]=i;
59 }
60 else if(dis[u[i]]+dis[v[i]]==MX)
61 st[++top]=i;
62 }
63 ans=min(ans,MX);
64 for(int i=1;i<=top;i++)
65 {
66 if(bel[u[st[i]]]!=bel[v[st[i]]])
67 {
68 printf("%d\n",ans);
69 exit(0);
70 }
71 else
72 {
73 if(!pos) pos=bel[u[st[i]]];
74 else if(pos!=bel[u[st[i]]])
75 {
76 printf("%d\n",ans);
77 exit(0);
78 }
79 }
80 }
81 S=size[pos],rt=0;
82 getroot(pos,0),solve(rt);
83 }
84
85 int main()
86 {
87 scanf("%d%d",&n,&m);
88 for(int i=1;i<n;i++)
89 {
90 int a,b,c;scanf("%d%d%d",&a,&b,&c);
91 add(a,b,c),add(b,a,c);
92 }
93 for(int i=1;i<=m;i++) scanf("%d%d",&u[i],&v[i]);
94 S=maxn[0]=n,getroot(1,0),solve(rt);
95 return 0;
96 }