zoukankan      html  css  js  c++  java
  • HDU 5977 Garden of Eden (树形dp+快速沃尔什变换FWT)

    CGZ大佬提醒我,我要是再不更博客可就连一月一更的频率也没有了。。。

    emmm,正好做了一道有点意思的题,就拿出来充数吧=。=

    题意

    一棵树,有 $ n (nleq50000) $ 个节点,每个点都有一个颜色,共有 $ k(kleq10) $ 种颜色,问有多少条路径可以遍历到所有 $ k $ 种颜色?(一条路径交换起点终点就算两条哦)

    做法

    事实证明,连我都能不看题解想出来的题果然都是水题qwq

    我是从CJ的xzyxzy大佬上的博客上看到这道题的,所以就理所当然用FWT做了...然后才发现网上的题解都是点分治...Menhera大佬提供了一个更优的做法,不过我是真的看不懂...放在最后讲一下(在代码后面)。

    这道题一眼就是树形dp,而且k特别小,貌似可以状压?

    用二进制数 $ S $ 表示一条路径上的颜色种类,用 $ dp[i][S] $ 表示当前节点 $ i $ 到它下面的叶子节点中,颜色状态为 $ S $ 的路径的数量。很显然 $ S=2^k-1 $ 的路径就是我们要找的路径,我们的目标就是求出这样的路径的数量ヾ(゚∀゚ゞ)!

    求出 $ dp[i][S] $ 是很容易的,只需要遍历一遍就行了。然而,有的路径的两端会在 $ i $ 的两个子树中而横跨 $ i $ 这个结点,这样的路径怎么统计呢?总不能一个个枚举吧。。这就该FWT上场了!把要统计的两个子树的 $ dp[x][S] $ 做or卷积,然后把 $ S=111...1 $ 的路径条数累加进答案就可以啦!注意,FWT是在辅助数组中进行的,不应该改变原数组。在一遍FWT后,将第二个儿子的路径数累加进第一个儿子的路径数,接着将第三个儿子又与第一个儿子做FWT,以此类推,即可求出所有横跨各个子树的路径了。这样做时间复杂度是 $ O(n2^kk) $ ,而这道题的时限是5s,所以还是可以轻松跑过的。

    emmm...(在打了一会代码之后)

    等等,哪里有点问题...树上有50000个点,每个点需要大小为1024的数组来存储状态,我需要开整整195MB的数组?!

    经过一番思考...我终于发现了这样的方法:先一遍dfs求出每个节点的重儿子,dp时优先递归重儿子,然后递归别的儿子,一遍FWT求出答案后再将两个儿子的状态合并,回收轻儿子的空间,在接着递归别的儿子。这样,可以证明在某一时刻最多同时存在 $ log_2n $ 个儿子的状态(与树剖的证明相似),所以空间就只需要开一点点啦~

    代码:(有很详细的注释哦qwq不会的可以看一下代码)

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <stack>
    #define R register
    using namespace std;
    typedef long long LL;
    const int MAXN=50100;
    const int MAXM=1100;
    int he[MAXN],col[MAXN];
    int siz[MAXN],son[MAXN];
    int dp[100][MAXM];
    int n,k,cnt,len;
    LL ans;//注意ans是有可能爆int的
    
    template<class T>int read(T &x)//这是zyf看了会沉默的可以判EOF的快读
    {
    	x=0;int ff=0;char ch=getchar();
    	while((ch<'0'||ch>'9')&&ch!=EOF){ff|=(ch=='-');ch=getchar();}
    	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    	x=ff?-x:x;
    	if(ch==EOF)return EOF;
    	return 1;
    }
    
    class FWT
    {
    private:
    	LL tem1[MAXM],tem2[MAXM];
    	
    	void fwt_or(LL *a,int f)//普通的FWT,但是zyf说or和and卷积都应该叫做快速莫比乌斯变换(FMT)?
    	{
    		for(R int i=1;i<len;i<<=1)
    		  for(R int j=0;j<len;j+=(i<<1))
    		    for(R int k=0;k<i;++k)
    		      if(f==1)a[j+k+i]+=a[j+k];
    		      else a[j+k+i]-=a[j+k];
    	}
    	
    public:
    	int fwt(int *a,int *b)
    	{
    		for(R int i=0;i<len;++i)tem1[i]=a[i];
    		for(R int i=0;i<len;++i)tem2[i]=b[i];
    		fwt_or(tem1,1);
    		fwt_or(tem2,1);
    		for(R int i=0;i<len;++i)tem1[i]*=tem2[i];
    		fwt_or(tem1,-1);
    		return (int)tem1[len-1];
    	}
    }fwt;
    
    struct edge
    {
    	int to,next;
    }ed[MAXN<<1];
    
    void added(int x,int y)//加边,常规操作
    {
    	ed[++cnt].to=y;
    	ed[cnt].next=he[x];
    	he[x]=cnt;
    }
    
    void dfs_pre(int x,int fa)//求重儿子
    {
    	siz[x]=1,son[x]=0;
    	for(int i=he[x];i;i=ed[i].next)
    	{
    		int to=ed[i].to;
    		if(to==fa)continue;
    		dfs_pre(to,x);
    		siz[x]+=siz[to];
    		if(siz[to]>siz[son[x]])son[x]=to;
    	}
    }
    
    stack<int>stk;//用于回收空间
    int dfs(int x,int fa)
    {
    	int num,bt=1<<col[x];
    	if(son[x])num=dfs(son[x],x);//继承重儿子的空间
    	else//是叶节点
    	{
    		if(!stk.empty())//使用回收后的空间
    		  num=stk.top(),stk.pop();
    		else num=++cnt;//使用新空间
    		dp[num][bt]=1;
    		return num;//上传空间给父亲
    	}
    	for(R int i=1;i<len;++i)//根据重儿子的状态推出自己的状态
    	  if(dp[num][i])
    	    if(!(i&bt))
    	      dp[num][i|bt]+=dp[num][i],dp[num][i]=0;
    	ans+=dp[num][len-1];
    	dp[num][bt]+=1;
    	for(int j=he[x];j;j=ed[j].next)//枚举轻儿子们
    	{
    		int to=ed[j].to;
    		if(to==fa||to==son[x])continue;
    		int tnum=dfs(to,x);//得到轻儿子的状态
    		ans+=fwt.fwt(dp[num],dp[tnum]);//FWT并累加答案
    		for(R int i=1;i<len;++i)//将这个轻儿子的状态合并至自己的状态
    		  if(dp[tnum][i])
    			if(!(i&bt))
    			  dp[num][i|bt]+=dp[tnum][i],dp[tnum][i]=0;
    			else dp[num][i]+=dp[tnum][i],dp[tnum][i]=0;
    		stk.push(tnum);//空间回收
    	}
    	return num;//上传空间给父亲
    }
    
    int main()
    {
    	while(read(n)!=EOF)
    	{
    		memset(he,0,sizeof(he));
    		memset(dp,0,sizeof(dp));
    		while(!stk.empty())stk.pop();//各种清零不要忘
    		read(k);
    		len=1<<k;
    		for(R int i=1;i<=n;++i)
    		  read(col[i]),--col[i];
    		int t1,t2;
    		cnt=0,ans=0;
    		for(R int i=1;i<n;++i)//加边
    		{
    			read(t1),read(t2);
    			added(t1,t2);
    			added(t2,t1);
    		}
    		if(k==1)//特判就好了,可以省很多事
    		{
    			ans=1ll*n*(n-1);
    			ans+=n;
    			printf("%lld
    ",ans);
    			continue;
    		}
    		dfs_pre(1,0);
    		cnt=0;//这里cnt被用于标记当前使用的空间
    		dfs(1,0);
    		ans<<=1;//别忘了交换起点终点后的路径算的不同路径
    		printf("%lld
    ",ans);
    	}
    	return 0;
    }
    

    什么?优化?

    我一个在明德的好朋友Menhera酱发现可以把复杂度中的 $ k $ 去掉,变成 $ O(n2^k) $ ,具体貌似是利用“基”的形式进行 $ O(2^k) $ 的FWT,具体可以看她的这篇博客:https://www.cnblogs.com/Menhera/p/9514412.html (那不足50行的代码真是震撼我心)

    Menhera:“我的这个做法是什么点分治的优化版,而你的是弱化版~”(事实是我1700MS她1200MS。。。明明去了一个log然而感觉就像卡了常一样)

    只感觉智商被碾压qwq。。。Menhera太强了orz

  • 相关阅读:
    回文字符串问题
    Linux添加nfs共享存储盘
    解读nginx配置
    制作自己的nginx rpm包
    linux编译安装时常见错误解决办法
    redis单机及集群安装
    nginx ssl
    vsftp配置详解
    Linux-文件系统的简单操作
    Linux-Vim编辑器
  • 原文地址:https://www.cnblogs.com/sclbgw7/p/9508235.html
Copyright © 2011-2022 走看看