zoukankan      html  css  js  c++  java
  • 【CF809E】Surprise me! 树形DP 虚树 数学

    题目大意

      给你一棵(n)个点的树,每个点有权值(a_i)(a)为一个排列,求

    [frac{1}{n(n-1)}sum_{i=1}^nsum_{j=1}^n varphi(a_ia_j)dist_{i,j} ]

      (nleq 200000)

    题解

      欧拉phi函数

    [egin{align} ans&=frac{1}{n(n-1)}sum_{i=1}^nsum_{j=1}^n varphi(a_ia_j)dist_{i,j}\ &=frac{1}{n(n-1)}sum_{i=1}^nsum_{j=1}^nsum_{d=(a_i,a_j)} frac{varphi(a_i)varphi(a_j)d}{varphi(d)}dist_{i,j}\ &=frac{1}{n(n-1)}sum_{d=1}^nfrac{d}{mu(d)}sum_{d=(a_i,a_j)}varphi(a_i)varphi(a_j)dist_{i,j}\ f(d)&=sum_{d=(a_i,a_j)}varphi(a_i)varphi(a_j)dist_{i,j}\ F(d)&=sum_{d|a_i,d|a_j}varphi(a_i)varphi(a_j)dist_{i,j}\ F(d)&=sum_{d|n}f(n)\ f(d)&=F(d)-sum_{d|n,d eq n}f(n) end{align} ]

      (F(d))可以直接建虚树DP求。

      然后直接反演统计就可以得到答案。

      总的点数是(sum_{i=1}^nlfloorfrac{n}{i} floor=O(nlog n))

      所以总的时间复杂度是(O(nlog^2 n))

    代码

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cstdlib>
    #include<ctime>
    #include<utility>
    using namespace std;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<int,int> pii;
    ll p=1000000007;
    struct graph
    {
    	int h[200010];
    	int v[400010];
    	int w[400010];
    	int t[400010];
    	int n;
    	graph()
    	{
    		memset(h,0,sizeof h);
    		n=0;
    	}
    	void add(int x,int y,int z)
    	{
    		n++;
    		v[n]=y;
    		w[n]=z;
    		t[n]=h[x];
    		h[x]=n;
    	}
    };
    graph g,g2;
    int f[200010][20];
    int d[200010];
    int st[200010];
    int ti;
    void dfs(int x,int fa,int dep)
    {
    	f[x][0]=fa;
    	d[x]=dep;
    	st[x]=++ti;
    	int i;
    	for(i=1;i<=19;i++)
    		f[x][i]=f[f[x][i-1]][i-1];
    	for(i=g.h[x];i;i=g.t[i])
    		if(g.v[i]!=fa)
    			dfs(g.v[i],x,dep+1);
    }
    int getlca(int x,int y)
    {
    	if(d[x]<d[y])
    		swap(x,y);
    	int i;
    	for(i=19;i>=0;i--)
    		if(d[f[x][i]]>=d[y])
    			x=f[x][i];
    	if(x==y)
    		return x;
    	for(i=19;i>=0;i--)
    		if(f[x][i]!=f[y][i])
    		{
    			x=f[x][i];
    			y=f[y][i];
    		}
    	return f[x][0];
    }
    ll phi[200010];
    int b[200010];
    int pri[100010];
    int cnt;
    ll inv[200010];
    void init(int n)
    {
    	int i,j;
    	inv[0]=inv[1]=1;
    	for(i=2;i<=n;i++)
    		inv[i]=-(p/i)*inv[p%i]%p;
    	phi[1]=1;
    	cnt=0;
    	for(i=2;i<=n;i++)
    	{
    		if(!b[i])
    		{
    			pri[++cnt]=i;
    			phi[i]=i-1;
    		}
    		for(j=1;j<=cnt&&i*pri[j]<=n;j++)
    		{
    			b[i*pri[j]]=1;
    			if(i%pri[j]==0)
    			{
    				phi[i*pri[j]]=phi[i]*pri[j];
    				break;
    			}
    			phi[i*pri[j]]=phi[i]*phi[pri[j]];
    		}
    	}
    }
    ll a[200010];
    ll s[200010];
    int c[200010];
    int c1[200010];
    int ct;
    int n;
    int stack[200010];
    int top;
    int cmp(int a,int b)
    {
    	return st[a]<st[b];
    }
    ll s1[200010];
    ll s2[200010];
    ll sum;
    void add(int x,int y)//f[x]=y
    {
    	ll s3=(s1[x]+(d[x]-d[y])*s2[x])%p;
    	sum=(sum+s3*s2[y]+s1[y]*s2[x])%p;
    	s1[y]=(s1[y]+s3)%p;
    	s2[y]=(s2[y]+s2[x])%p;
    }
    ll solve(int x)
    {
    	sum=0;
    	ct=top=0;
    	int i;
    	for(i=x;i<=n;i+=x)
    		c1[++ct]=c[i];
    	sort(c1+1,c1+ct+1,cmp);
    	int rt=getlca(c1[1],c1[ct]);
    	if(rt!=c1[1])
    	{
    		stack[++top]=rt;
    		s1[rt]=s2[rt]=0;
    	}
    	for(i=1;i<=ct;i++)
    	{
    		 if(i>=2)
    		 {
    		 	int lca=getlca(c1[i],c1[i-1]);
    		 	while(d[stack[top]]>d[lca])
    		 		if(d[stack[top-1]]<d[lca])
    		 		{
    		 			s1[lca]=s2[lca]=0;
    		 			add(stack[top],lca);
    		 			stack[top]=lca;
    		 		}
    		 		else
    		 		{
    		 			add(stack[top],stack[top-1]);
    		 			top--;
    		 		}
    		 }
    		 stack[++top]=c1[i];
    		 s1[c1[i]]=0;
    		 s2[c1[i]]=phi[a[c1[i]]];
    	}
    	while(top>1)
    	{
    		add(stack[top],stack[top-1]);
    		top--;
    	}
    	return sum*2%p;
    }
    int main()
    {
    	scanf("%d",&n);
    	init(n);
    	int i,x,y,j;
    	for(i=1;i<=n;i++)
    	{
    		scanf("%lld",&a[i]);
    		c[a[i]]=i;
    	}
    	for(i=1;i<n;i++)
    	{
    		scanf("%d%d",&x,&y);
    		g.add(x,y,0);
    		g.add(y,x,0);
    	}
    	dfs(1,0,1);
    	for(i=1;i<=n;i++)
    		s[i]=solve(i);
    	ll ans=0;
    	for(i=n;i>=1;i--)
    	{
    		for(j=i+i;j<=n;j+=i)
    			s[i]-=s[j];
    		ans=(ans+s[i]*i%p*inv[phi[i]]%p)%p;
    	}
    	ans=ans*inv[n]%p*inv[n-1]%p;
    	ans=(ans+p)%p;
    	printf("%lld
    ",ans);
    	return 0;
    }
    
  • 相关阅读:
    youcompleteme-Vim补全插件安装
    depthimage_to_laserscan代码解读
    如何创建离线化 mapbox sprite精灵图
    mapbox/node-fontnik工具使用介绍
    跟我学习dubbo-使用Maven构建Dubbo服务的可执行jar包(4)
    跟我学习dubbo-Dubbo管理控制台的安装(3)
    跟我学习dubbo-ZooKeeper注册中心安装(2)
    跟我学习dubbo-简介(1)
    跟我学习dubbo-构建Dubbo服务消费者Web应用的war包并在Tomcat中部署(6)
    跟我学习dubbo-在Linux操作系统上手工部署Dubbo服务(5)
  • 原文地址:https://www.cnblogs.com/ywwyww/p/8511453.html
Copyright © 2011-2022 走看看