传送门:https://atcoder.jp/contests/abc163/tasks/abc163_f
题目大意:给定一棵 n 个节点的树,每个节点有一个颜色(颜色编号为 1..n)。对每一种颜色 k,求至少经过一个颜色为 k 的节点的简单路径数量(路径视为无序点对 {u, v},u = v 的单点路径也计入,因此路径总数为 n(n+1)/2)。
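分析:虽然本题有 O(n) 的做法,这里还是先贴一下虚树的做法。思路是补集计数:对每一种颜色,把该颜色的节点标记出来并建立虚树。删去所有标记节点后,树会分裂成若干个只含非标记节点的连通块,不经过该颜色的路径恰好就是完全落在某个连通块内的路径。因此答案等于路径总数 n(n+1)/2 减去每个连通块贡献的 s(s+1)/2(s 为连通块大小)。对每个标记节点的每个儿子 v,其下方连通块的大小等于 v 的子树大小减去 v 子树中最上层标记节点的子树大小之和;若根未被标记,包含根的连通块大小则用 n 减去全树中最上层标记节点的子树大小之和来计算。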
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

typedef long long LL;

// Reads one (optionally negative) decimal integer from stdin.
// getchar() returns an int so that EOF stays distinguishable from every
// character; the skip loop stops at EOF instead of spinning forever.
inline int read() {
    int res = 0, f = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        res = res * 10 + (ch - '0');
        ch = getchar();
    }
    return res * f;
}

const int MAXN = 200'005;
const int LOG = 20;  // 2^20 > MAXN, so jump levels 0..LOG cover every depth

// Original tree: linked-list adjacency, each undirected edge stored twice.
int to[MAXN << 1], Next[MAXN << 1], head[MAXN], tol;
void add_edge(int u, int v) {
    Next[++tol] = head[u]; to[tol] = v; head[u] = tol;
    Next[++tol] = head[v]; to[tol] = u; head[v] = tol;
}

// Virtual tree of the colour currently being processed: directed parent -> child
// edges only. headv[x] is reset when x joins the virtual tree; tolv per colour.
int tov[MAXN << 1], Nextv[MAXN << 1], headv[MAXN], tolv;
void add_edgev(int u, int v) {
    Nextv[++tolv] = headv[u]; tov[tolv] = v; headv[u] = tolv;
}

int n, col[MAXN];
vector<int> nodes[MAXN];         // nodes[c] = vertices whose colour is c
int dfn[MAXN], R[MAXN], dfncnt;  // DFS entry index and last index inside the subtree
int up[MAXN][LOG + 1], dep[MAXN], sz[MAXN];
int st[MAXN], top;               // stack used while building the virtual tree
int mark[MAXN];                  // 1 iff the vertex has the colour being processed
LL ans;                          // answer for the colour being processed

// Number of paths (unordered vertex pairs, single-vertex paths included)
// inside a connected block of x vertices.
LL calc(LL x) { return x * (x + 1) / 2; }

// Fills binary-lifting table, depth, subtree size and DFS interval [dfn, R].
// Recursive with depth up to n: assumes a large enough stack.
void dfs(int u, int f) {
    sz[u] = 1;
    dfn[u] = ++dfncnt;
    for (int i = 0; i < LOG && up[u][i]; ++i) up[u][i + 1] = up[up[u][i]][i];
    for (int e = head[u]; e; e = Next[e]) {
        int v = to[e];
        if (v == f) continue;
        up[v][0] = u;
        dep[v] = dep[u] + 1;
        dfs(v, u);
        sz[u] += sz[v];
    }
    R[u] = dfncnt;
}

// Lowest common ancestor by binary lifting. Jumps past the root land on vertex 0,
// and dep[0] == 0 < dep[root] == 1, so the depth test rejects them.
int getLCA(int u, int v) {
    if (dep[u] < dep[v]) swap(u, v);
    for (int i = LOG; i >= 0; --i)
        if (dep[up[u][i]] >= dep[v]) u = up[u][i];
    if (u == v) return u;
    for (int i = LOG; i >= 0; --i)
        if (up[u][i] != up[v][i]) {
            u = up[u][i];
            v = up[v][i];
        }
    return up[u][0];
}

// Orders vertices by DFS entry time.
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }

// Walks the virtual tree below u. Returns the total size of the topmost marked
// subtrees inside u's subtree (sz[u] itself when u is marked). For every marked u,
// subtracts from `ans` the paths inside each unmarked component hanging directly
// below u; for an unmarked root, subtracts the component that contains the root.
LL dfs1(int u) {
    LL s = 0;
    vector<pair<int, LL> > kids;  // (dfn, covered size) of virtual children
    for (int e = headv[u]; e; e = Nextv[e]) {
        int v = tov[e];
        LL t = dfs1(v);
        s += t;
        if (mark[u]) kids.push_back(make_pair(dfn[v], t));
    }
    if (mark[u]) {
        // Virtual children are not inserted in DFS order, so they must be sorted
        // (skipping this sort gave wrong answers). Real children are iterated in
        // the same adjacency order `dfs` used, i.e. by increasing dfn, so a single
        // pointer can sweep both lists together.
        sort(kids.begin(), kids.end());
        size_t idx = 0;
        for (int e = head[u]; e; e = Next[e]) {
            int v = to[e];
            if (v == up[u][0]) continue;  // skip the real parent
            LL covered = 0;               // marked-subtree mass inside v's subtree
            while (idx < kids.size() && kids[idx].first >= dfn[v] &&
                   kids[idx].first <= R[v]) {
                covered += kids[idx].second;
                ++idx;
            }
            ans -= calc(sz[v] - covered);
        }
        return sz[u];
    }
    if (u == 1) ans -= calc(sz[1] - s);  // unmarked component containing the root
    return s;
}

// Builds the virtual tree (rooted at vertex 1) of the dfn-sorted vertices vs and
// marks every vertex of vs. Returns the virtual root, which is always vertex 1.
int buildVirtualTree(const vector<int>& vs) {
    st[top = 1] = 1;
    headv[1] = 0;
    tolv = 0;
    for (int x : vs) {
        mark[x] = 1;
        if (x == 1) continue;  // the root already sits at the bottom of the stack
        int l = getLCA(st[top], x);
        if (l != st[top]) {
            // Pop every stack vertex strictly deeper than l, linking it to the
            // vertex beneath it on the stack.
            while (dfn[l] < dfn[st[top - 1]]) {
                add_edgev(st[top - 1], st[top]);
                --top;
            }
            if (dfn[l] > dfn[st[top - 1]]) {
                // l is new: it adopts the current top and takes its place.
                headv[l] = 0;
                add_edgev(l, st[top]);
                st[top] = l;
            } else {
                // l is already on the stack directly below the top.
                add_edgev(l, st[top--]);
            }
        }
        headv[x] = 0;
        st[++top] = x;
    }
    // What remains on the stack is one root-to-leaf chain.
    for (int i = 1; i < top; ++i) add_edgev(st[i], st[i + 1]);
    return st[1];
}

int main() {
    n = read();
    for (int i = 1; i <= n; ++i) {
        col[i] = read();
        nodes[col[i]].push_back(i);
    }
    for (int i = 1; i < n; ++i) {
        int u = read();
        int v = read();
        add_edge(u, v);
    }
    dep[1] = 1;  // dep[0] stays 0 so getLCA never lifts past the root
    dfs(1, 0);
    for (int color = 1; color <= n; ++color) {
        vector<int>& vs = nodes[color];
        if (vs.empty()) {
            puts("0");
            continue;
        }
        ans = calc(n);
        sort(vs.begin(), vs.end(), cmp);  // construction requires DFS order
        dfs1(buildVirtualTree(vs));
        printf("%lld\n", ans);
        for (int x : vs) mark[x] = 0;
    }
    return 0;
}
顺便附上 O(n) 的做法:只做一次 DFS,对每种颜色 c 维护 sum[c](已访问部分中最上层颜色为 c 的节点的子树大小之和),回溯时即可直接算出每个不含颜色 c 的连通块大小。
#include <cstdio>
using namespace std;

typedef long long LL;

// Reads one (optionally negative) decimal integer from stdin.
// getchar() returns an int so that EOF stays distinguishable from every
// character; the skip loop stops at EOF instead of spinning forever.
inline int read() {
    int res = 0, f = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        res = res * 10 + (ch - '0');
        ch = getchar();
    }
    return res * f;
}

const int MAXN = 200'005;

int n;
int col[MAXN];
LL sz[MAXN];   // subtree sizes
LL sum[MAXN];  // sum[c]: total size of the topmost colour-c subtrees seen so far
LL ans[MAXN];  // ans[c]: running answer for colour c

// Tree: linked-list adjacency, each undirected edge stored twice.
int to[MAXN << 1], Next[MAXN << 1], head[MAXN], tol;
void add_edge(int u, int v) {
    Next[++tol] = head[u]; to[tol] = v; head[u] = tol;
    Next[++tol] = head[v]; to[tol] = u; head[v] = tol;
}

// Number of paths (unordered vertex pairs, single-vertex paths included)
// inside a connected block of x vertices.
LL calc(LL x) { return x * (x + 1) / 2; }

// Single DFS. For each child v of u (colour c), the vertices of v's subtree that
// are not inside a deeper colour-c subtree form one component with no colour c;
// its paths are subtracted from ans[c]. Recursive with depth up to n: assumes a
// large enough stack.
void dfs(int u, int f) {
    const int c = col[u];
    sz[u] = 1;
    const LL before = sum[c];
    for (int e = head[u]; e; e = Next[e]) {
        int v = to[e];
        if (v == f) continue;
        const LL prev = sum[c];
        dfs(v, u);
        // (sum[c] - prev) is the colour-c subtree mass that appeared inside v's subtree.
        ans[c] -= calc(sz[v] - (sum[c] - prev));
        sz[u] += sz[v];
    }
    // u's whole subtree is now one topmost colour-c subtree, replacing the
    // colour-c subtrees inside it that were added during the recursion.
    sum[c] = before + sz[u];
}

int main() {
    n = read();
    const LL total = calc(n);
    for (int i = 1; i <= n; ++i) {
        col[i] = read();
        ans[i] = total;
    }
    for (int i = 1; i < n; ++i) {
        int u = read();
        int v = read();
        add_edge(u, v);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; ++i) {
        // The component containing the root holds every vertex outside the
        // topmost colour-i subtrees (all n vertices when colour i is absent).
        ans[i] -= calc(n - sum[i]);
        printf("%lld\n", ans[i]);
    }
    return 0;
}