zoukankan      html  css  js  c++  java
  • 51Nod1868 彩色树 虚树

    原文链接https://www.cnblogs.com/zhouzhendong/p/51Nod1868.html

    题目传送门 - 51Nod1868

    题意

      给定一颗 $n$个点的树,每个点一个 $[1,n]$ 的颜色。设 $g(x,y)$ 表示 $x$ 到 $y$ 的树上路径上有几种颜色。

      对于一个长度为 $n$ 的排列 $P[1cdots n]$ ,定义 $f(P)=sum_{i=1}^{n-1}g(P_i,P_{i+1})$ 。

      现在求对于 $n!$ 个排列,他们的 $f(P)$ 之和 对 $10^9+7$ 取模后的值。

    题解

      首先我们考虑每一个 $g(x,y)$ 对于答案的贡献次数。

      考虑捆绑法,把 $x$ 和 $y$ 看作一个整体,显然,它对答案的贡献次数为 $(n-1)!$ 。

      于是答案就是

    $$2 imes (n-1)!sum_{x=1}^{n}sum_{y=x+1}^{n} g(x,y)$$

      前面的 $2 imes (n-1)!$ 很好办,现在主要要求后面的那个。

      我们考虑对于每一个颜色分别处理。我们需要求出每一个颜色对答案的贡献。

      记 $f(c,x,y)$ 表示路径 $x$~$y$ 上,如果有颜色 $c$ ,那么值为 $1$ ,否则为 $0$ 。则后面一半变成了:

    $$sum_{c=1}^{n}sum_{x=1}^{n}sum_{y=x+1}^{n} f(c,x,y)$$

      确定一种颜色之后,后面的显然非常好求,直接一个树形dp 就差不多了。但是这样的时间复杂度是炸掉的。于是我们需要一个数据结构来优化——虚树。

      建出虚树,然后我们注意一下细节,统计一下就可以了。

      这里推荐一个写的比较详细的虚树学习笔记:https://www.k-xzy.xyz/archives/4476

    代码

    #include <bits/stdc++.h>
    using namespace std;
    typedef long long LL;
    const int N=200005,mod=1e9+7;
    int read(){
    	int x=0;
    	char ch=getchar();
    	while (!isdigit(ch))
    		ch=getchar();
    	while (isdigit(ch))
    		x=(x<<1)+(x<<3)+ch-48,ch=getchar();
    	return x;
    }
    struct Gragh{
    	static const int M=N*2;
    	int cnt,y[M],nxt[M],fst[N];
    	void clear(){
    		cnt=0;
    		memset(fst,0,sizeof fst);
    	}
    	void add(int a,int b){
    		y[++cnt]=b,nxt[cnt]=fst[a],fst[a]=cnt;
    	}
    }g,t;
    int n,c[N],Fac[N],Time=0,now_color,ans=0;
    int dfn[N],depth[N],size[N],fa[N][18],sqrsum[N];
    int dirson[N],tot[N],st[N],top;
    vector <int> id[N];
    LL calc(int x){
    	return 1LL*x*(x-1)/2;
    }
    void dfs(int x,int pre,int d){
    	dfn[x]=++Time,depth[x]=d,size[x]=1,fa[x][0]=pre,sqrsum[x]=0;
    	for (int i=1;i<18;i++)
    		fa[x][i]=fa[fa[x][i-1]][i-1];
    	for (int i=g.fst[x];i;i=g.nxt[i])
    		if (g.y[i]!=pre){
    			int y=g.y[i];
    			dfs(y,x,d+1);
    			size[x]+=size[y];
    			sqrsum[x]=(calc(size[y])+sqrsum[x])%mod;
    		}
    }
    int LCA(int x,int y){
    	if (depth[x]<depth[y])
    		swap(x,y);
    	for (int i=17;i>=0;i--)
    		if (depth[x]-(1<<i)>=depth[y])
    			x=fa[x][i];
    	if (x==y)
    		return x;
    	for (int i=17;i>=0;i--)
    		if (fa[x][i]!=fa[y][i])
    			x=fa[x][i],y=fa[y][i];
    	return fa[x][0];
    }
    bool cmp(int a,int b){
    	return dfn[a]<dfn[b];
    }
    void solve(int x){
    	int dx=dirson[x],sonsqr=tot[x]=0;
    	for (int k=t.fst[x];k;k=t.nxt[k]){
    		int y=t.y[k],&dy=dirson[y]=y;
    		for (int i=17;i>=0;i--)
    			if (depth[dy]-(1<<i)>depth[x])
    				dy=fa[dy][i];
    		solve(y);
    		tot[x]+=tot[y];
    		sonsqr=(calc(tot[y])+sonsqr)%mod;
    	}
    	if (c[x]==now_color){
    		tot[x]=size[x];
    		int xsqr=(calc(size[dx]-size[x])+sqrsum[x])%mod;
    		ans=(calc(size[dx])-xsqr+ans)%mod;
    	}
    	else {
    		ans=(calc(tot[x])-sonsqr+ans)%mod;
    		for (int i=t.fst[x];i;i=t.nxt[i]){
    			int y=t.y[i],v=size[dx]-tot[x]+tot[y]-size[dirson[y]];
    			ans=(1LL*tot[y]*v+ans)%mod;
    		}
    	}
    }
    int main(){
    	n=read();
    	for (int i=Fac[0]=1;i<=n;i++)
    		c[i]=read(),Fac[i]=1LL*Fac[i-1]*i%mod;
    	g.clear();
    	for (int i=1;i<n;i++){
    		int a=read(),b=read();
    		g.add(a,b);
    		g.add(b,a);
    	}
    	dfs(1,0,0);
    	for (int i=1;i<=n;i++)
    		id[i].clear();
    	for (int i=1;i<=n;i++)
    		id[c[i]].push_back(i);
    	t.clear();
    	for (int k=1;k<=n;k++){
    		if (id[k].size()<1)
    			continue;
    		sort(id[k].begin(),id[k].end(),cmp);
    		st[top=1]=1,t.fst[1]=0;
    		for (vector <int> :: iterator i=id[k].begin();i!=id[k].end();i++){
    			int x=*i;
    			if (x==1)
    				continue;
    			int lca=LCA(x,st[top]);
    			if (lca!=st[top]){
    				while (depth[st[top-1]]>depth[lca])
    					t.add(st[top-1],st[top]),top--;
    				if (st[top-1]!=lca)
    					t.fst[lca]=0,t.add(lca,st[top]),st[top]=lca;
    				else
    					t.add(lca,st[top--]);
    			}
    			t.fst[x]=0,st[++top]=x;
    		}
    		for (int i=1;i<top;i++)
    			t.add(st[i],st[i+1]);
    		now_color=k,dirson[1]=1;
    		solve(1);
    	}
    	printf("%d
    ",2LL*(ans+mod)%mod*Fac[n-1]%mod);
    	return 0;
    }
    

      

  • 相关阅读:
    数据库期末考试复习
    函数 初识
    文件操作
    深浅copy 和 集合
    数据编码补充
    字典的增删改查和嵌套
    面试题 和 python 2与3的期区别
    英文练习
    初识数据类型
    测试基础-系统测试(2)
  • 原文地址:https://www.cnblogs.com/zhouzhendong/p/51Nod1868.html
Copyright © 2011-2022 走看看