zoukankan      html  css  js  c++  java
  • P4103 [HEOI2014]大工程 虚树

    题意:

    戳这里

    分析:

    虚树板子题

    首先有一个 (O(qn)) 的暴力,就是对于每一次询问, (O(n)) 的树上 DP ,我们统计一下每一个点,它的子树内离它最近/远的关键点的距离,已经关键点的个数

    对于第一个询问等价于 (sum dep(x)+dep(y)-sum2 imes dep(lca))

    我们 (dp) 的时候顺便统计一下每一个点作为 (lca) 出现了多少次,这个直接扫一下儿子就能得到

    第二个询问按照我们 (dp) 数组记下的状态枚举一下两个子树就可以得到

    我们发现这种 树上(DP) 多次询问每次给定点集(点集总和与 (n) 同阶) 的问题直接建出虚树这样每次 (dp) 的复杂度降低到和点数同阶,总的复杂度不超过 (O(nlog))

    代码:

    #include<bits/stdc++.h>
    
    using namespace std;
    
    namespace zzc
    {
        int read()
        {
            int x=0,f=1;char ch=getchar();
            while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
            while(isdigit(ch)){x=x*10+ch-48;ch=getchar();}
            return x*f;
        }
    
        const int maxn = 1e6+5;
        const int inf = 0x3f3f3f3f;
        int n,idx,num,top,qt,ans2,ans3,mx[maxn],mn[maxn];
        long long ans1;
        int dfn[maxn],fa[maxn][22],st[maxn],p[maxn],dep[maxn],sum[maxn];
        bool vis[maxn];
    
        struct tree
        {
            int cnt,head[maxn];
            struct edge
            {
                int to,nxt;
            }e[maxn<<1];
    
            void add(int u,int v)
            {
                e[++cnt].to=v;
                e[cnt].nxt=head[u];
                head[u]=cnt;
                
                e[++cnt].to=u;
                e[cnt].nxt=head[v];
                head[v]=cnt;
            }
        }t1,t2;
        
        bool cmp(int x,int y)
        {
            return dfn[x]<dfn[y];
        }
    
        void dfs1(int u,int ff)
        {
            dfn[u]=++idx;fa[u][0]=ff;dep[u]=dep[ff]+1;
            for(int i=t1.head[u];i;i=t1.e[i].nxt)
            {
                int v=t1.e[i].to;
                if(v!=ff) dfs1(v,u);
            }
        }
        
        inline int lca(int x,int y)
        {
            if(dep[x]<dep[y]) swap(x,y);
            for(int i=21;i>=0;i--) if(dep[fa[x][i]]>=dep[y]) x=fa[x][i];
            if(x==y) return x;
            for(int i=21;i>=0;i--)
            {
                if(fa[x][i]!=fa[y][i])
                {
                    x=fa[x][i];
                    y=fa[y][i];
                }
            }
            return fa[x][0];
        }
    
        inline void build()
        {
            sort(p+1,p+num+1,cmp);t2.cnt=0;
            st[top=1]=1;t2.head[1]=0;
            for(int i=1;i<=num;i++)
    		{
    			if(p[i]!=1)
    			{
    				int x=lca(p[i],st[top]);
    				if(x!=st[top])
    				{
    					while(top>1&&dfn[x]<dfn[st[top-1]]) t2.add(st[top-1],st[top]),top--;
    					if(top>1&&dfn[x]!=dfn[st[top-1]])
    					{
    						t2.head[x]=0;
    						t2.add(x,st[top]);
    						st[top]=x;
    					}
    					else t2.add(x,st[top--]);
    				}
    				st[++top]=p[i];t2.head[p[i]]=0;
    			}
    			
    		}
    		while(top>1) t2.add(st[top-1],st[top]),top--;
        }
    	
    	void dfs2(int u,int ff)
    	{
    		sum[u]=0;mx[u]=-inf;mn[u]=inf;
    		if(vis[u]) mn[u]=0,mx[u]=0,sum[u]++;
    		for(int i=t2.head[u];i;i=t2.e[i].nxt)
    		{
    			int v=t2.e[i].to;
    			if(v!=ff)
    			{
    				dfs2(v,u);
    				ans1-=1ll*sum[u]*sum[v]*2*dep[u];
    				ans2=min(ans2,mn[u]+mn[v]-dep[u]+dep[v]);
    				ans3=max(ans3,mx[u]+mx[v]-dep[u]+dep[v]);
    				mx[u]=max(mx[u],mx[v]-dep[u]+dep[v]);
    				mn[u]=min(mn[u],mn[v]-dep[u]+dep[v]);
    				sum[u]+=sum[v];
    			}
    			
    		}
    	}
    	
    	inline void solve()
    	{
    		for(int i=1;i<=num;i++) ans1+=1ll*(num-1)*dep[p[i]];
    		dfs2(1,0);
    		printf("%lld %d %d
    ",ans1,ans2,ans3);
    	}
    	
        void work()
        {
            int a,b;
            n=read();
            for(int i=1;i<n;i++)
            {
                a=read();b=read();
                t1.add(a,b);
            }
    		dep[0]=-1;dfs1(1,0);
            for(int j=1;j<=21;j++)
            {
                for(int i=1;i<=n;i++)
                {
                    fa[i][j]=fa[fa[i][j-1]][j-1];
                }
            }
            qt=read();
            while(qt--)
            {
            	ans1=0;ans2=inf;ans3=-inf;
            	num=read();
            	for(int i=1;i<=num;i++) p[i]=read(),vis[p[i]]=true;
            	build();
            	solve();
            	for(int i=1;i<=num;i++) vis[p[i]]=false;
    		}
        }
    
    
    }
    
    int main()
    {
        zzc::work();
        return 0;
    }
    
  • 相关阅读:
    validation 参数效验框架
    小酌一下:Maven
    小酌一下:git 常用命令
    小酌一下:anaconda 基本操作
    小酌一下:Win10 解决fetch_20newsgroups下载速度巨慢
    学习笔记:Python3 异常处理
    学习笔记:Python3 面向对象
    学习笔记:Python3 函数式编程
    学习笔记:Python3 函数
    学习笔记:Python3 高级特性
  • 原文地址:https://www.cnblogs.com/youth518/p/14245868.html
Copyright © 2011-2022 走看看