0
点赞
收藏
分享

微信扫一扫

树链剖分

Yaphets_巍 2022-03-16 阅读 117
c++
#include<iostream>
#include<cstring>
#include<vector>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<utility>
using namespace std;
const int N=1e6+50;
// Number of binary-lifting levels stored per vertex. 2^20 > N, so levels 0..20 cover
// every ancestor distance in a tree with at most N vertices (the previous bound of 100
// columns reserved ~400 MB of which only ~20 columns were ever used).
const int LOG=21;
// Adjacency lists stored as singly linked lists of edges inside one array.
struct edge{
    int to,nex;   // target vertex; index of the next edge leaving the same vertex (-1 ends the list)
}e[N<<1];         // each undirected tree edge is stored twice
int head[N],n;    // head[x]: index of x's first edge (-1 if none); n: number of vertices
int siz[N];       // siz[x]: number of vertices in x's subtree (tree rooted at vertex 1)
std::vector<int>v[N];  // v[c]: vertices i whose input value equals c
int dep[N];       // dep[x]: depth of x; the root has depth 1 (dep[0]==0 for the virtual parent 0)
int num;          // number of edges stored in e so far
int fa[N][LOG];   // fa[x][k]: 2^k-th ancestor of x, or 0 if it does not exist
// Sort predicate: places deeper vertices before shallower ones
// (x goes first exactly when its depth is strictly greater than y's).
bool cmp(int x,int y)
{
    const int depthX=dep[x];
    const int depthY=dep[y];
    return depthY<depthX;
}
// Appends the directed edge x -> y to x's adjacency list. The new edge becomes the
// head of the list; edge slots in e[] are handed out consecutively starting at num.
void add(int x,int y)
{
    const int slot=num++;
    e[slot].to=y;
    e[slot].nex=head[x];
    head[x]=slot;
}
// Roots the tree at x, whose parent is recorded as pre. For every vertex u reachable
// from x without stepping onto pre it sets:
//   fa[u][0] - parent of u (pre for x itself),
//   dep[u]   - depth of u (dep[pre]+1 for x),
//   siz[u]   - increased by the sizes of u's children. siz[] is only added to, so it
//              must be initialised beforehand (main sets every siz[i] to 1).
// Implemented as an explicit-stack post-order traversal. A recursive version nests one
// call per tree level, which overflows the call stack on path-like trees of ~1e6 vertices.
void dfs(int x,int pre)
{
    fa[x][0]=pre;
    dep[x]=dep[pre]+1;
    // Each frame: (vertex, index of the next edge of that vertex still to be scanned).
    std::vector<std::pair<int,int>> frames;
    frames.emplace_back(x,head[x]);
    while(!frames.empty())
    {
        const int u=frames.back().first;
        const int edgeIdx=frames.back().second;
        if(edgeIdx==-1)
        {
            // All children of u are finished: fold its size into the parent's frame.
            frames.pop_back();
            if(!frames.empty())
                siz[frames.back().first]+=siz[u];
            continue;
        }
        // Advance this frame's cursor before pushing, since emplace_back may reallocate.
        frames.back().second=e[edgeIdx].nex;
        const int to=e[edgeIdx].to;
        if(to==fa[u][0])
            continue;                  // do not walk back to the parent
        fa[to][0]=u;
        dep[to]=dep[u]+1;
        frames.emplace_back(to,head[to]);
    }
}
// Lowest common ancestor of u and v using the binary-lifting table fa[][].
// First lifts the deeper vertex to the other's depth, then lifts both together while
// their ancestors differ; entries that were never filled are 0 and compare equal.
int lca(int u,int v)
{
    if(dep[u]>dep[v])
        std::swap(u,v);            // v is now the deeper (or equally deep) vertex
    const int gap=dep[v]-dep[u];
    for(int bit=0;(1<<bit)<=gap;++bit)
    {
        if(gap&(1<<bit))
            v=fa[v][bit];
    }
    if(v==u)
        return u;
    const int topLevel=static_cast<int>(std::log2(static_cast<double>(n)));
    for(int level=topLevel;level>=0;--level)
    {
        if(fa[u][level]==fa[v][level])
            continue;
        u=fa[u][level];
        v=fa[v][level];
    }
    return fa[u][0];
}
// Number of unordered vertex pairs {a,b}, a != b, whose tree path contains vertex p.
// Requires dfs() (siz[], fa[][0]) and the adjacency lists to be built.
static long long countPathsThrough(int p)
{
    const int below=siz[p]-1;              // vertices strictly inside p's subtree
    long long acrossChildren=0;            // pairs from two different child subtrees, counted twice
    for(int k=head[p];k!=-1;k=e[k].nex)
    {
        const int c=e[k].to;
        if(c==fa[p][0])
            continue;                      // skip the parent edge
        acrossChildren+=1ll*(below-siz[c])*siz[c];
    }
    long long ans=acrossChildren/2;
    ans+=1ll*(n-siz[p])*(siz[p]-1);        // one endpoint outside p's subtree, the other strictly below p
    ans+=1ll*(n-1);                        // p itself is one endpoint
    return ans;
}

// Number of unordered pairs {a,b} whose tree path contains every vertex of group
// (group.size() >= 2). Sorts group by decreasing depth as a side effect.
static long long countPathsCovering(std::vector<int>& group)
{
    std::sort(group.begin(),group.end(),cmp);
    const int lowest=group[0];
    // Index of the first vertex (in decreasing depth) that is not an ancestor of lowest; 0 if none.
    std::size_t other=0;
    for(std::size_t j=1;j<group.size();++j)
    {
        if(lca(lowest,group[j])!=group[j])
        {
            other=j;
            break;
        }
    }

    if(other==0)
    {
        // All vertices are ancestors of lowest: they form a vertical chain ending at top.
        // One endpoint must lie in lowest's subtree. The other may be anywhere outside
        // the child subtree of top that contains lowest.
        const int top=group.back();
        long long outside=n;
        for(int k=head[top];k!=-1;k=e[k].nex)
        {
            const int c=e[k].to;
            if(c==fa[top][0])
                continue;
            if(lca(lowest,c)==c)
                outside-=siz[c];
        }
        return 1ll*siz[lowest]*outside;
    }

    // Two branches. Every vertex must lie on the path lowest..partner: it must be an
    // ancestor of one of the two ends and must not be above their LCA.
    const int partner=group[other];
    const int meet=lca(lowest,partner);
    for(std::size_t j=1;j<group.size();++j)
    {
        const int w=group[j];
        const int withLowest=lca(w,lowest);
        const int withPartner=lca(w,partner);
        const bool onLowestSide=(withLowest==w||withLowest==lowest);
        const bool onPartnerSide=(withPartner==w||withPartner==partner);
        if(!onLowestSide&&!onPartnerSide)
            return 0;
        if(dep[meet]>dep[w])
            return 0;
    }
    // Pairs with one endpoint in each subtree.
    // Equal to (siz[lowest]-1)*(siz[partner]-1) + 1 + (siz[lowest]-1) + (siz[partner]-1).
    return 1ll*siz[lowest]*siz[partner];
}

// Reads n, a value for each vertex 1..n, and n-1 undirected edges. Then, for each value
// i = 1..n, prints the number of unordered vertex pairs whose tree path contains every
// vertex carrying value i (all pairs when no vertex carries it).
int main ()
{
    // About 3n integers in and n lines out: drop C-stdio synchronisation (the program
    // uses no C stdio), untie cin from cout, and avoid std::endl's per-line flush.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin>>n;
    std::memset(head,-1,sizeof(head));     // -1 terminates every adjacency list
    for(int i=1;i<=n;++i)
    {
        siz[i]=1;                          // dfs() adds child sizes on top of this
        int value;
        std::cin>>value;
        v[value].push_back(i);
    }
    for(int i=1;i<n;++i)
    {
        int a,b;
        std::cin>>a>>b;
        add(a,b);
        add(b,a);
    }
    dfs(1,0);
    // Binary-lifting table: fa[i][j+1] = fa[fa[i][j]][j], where 0 means "no ancestor".
    for(int j=0;(1<<(j+1))<n;++j)
    {
        for(int i=1;i<=n;++i)
            fa[i][j+1]=(fa[i][j]<=0)?0:fa[fa[i][j]][j];
    }

    for(int i=1;i<=n;++i)
    {
        long long ans;
        if(v[i].empty())
            ans=1ll*n*(n-1)/2;             // no constraint: every unordered pair qualifies
        else if(v[i].size()==1)
            ans=countPathsThrough(v[i][0]);
        else
            ans=countPathsCovering(v[i]);
        std::cout<<ans<<'\n';
    }
    return 0;
}

 

举报

相关推荐

0 条评论