题意
给一棵(n)个节点的树,每个节点上有一个颜色,颜色种数至多(k)种。
求树上有多少条路径满足包含所有颜色。
(nle5*10^4,kle10)
sol
树上路径想到淀粉质,(kle10)想到状态压缩。
考虑过重心的路径,用一个二进制状态表示这条路径上包含的颜色集合。
对于当前一个颜色集合i,只要找到一个S^i和他配对就行了
吗?
显然可以是S^i这个集合的超集。
枚举超集?时间爆炸。
上网学了一发神奇的高维前缀和,可以解决这类的超集和子集和的问题
超集和
for (int j=0;j<k;++j)
for (int i=all;i>=0;--i)
if (((1<<j)&i)==0) tot[i]+=tot[i|(1<<j)];
子集和
for (int j=0;j<k;++j)
for (int i=0;i<=all;++i)
if ((1<<j)&i) tot[i]+=tot[i^(1<<j)];
不明觉厉
upt2018-05-23:
其实高维前缀和就是FWT,虽然算法实现略有不同,但是底层算法是一样的,而且复杂度也是一样的。
所以这题的正解是点分治+FWT
然后这题就做完啦~
code
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 5e4+5;
int n,k,all,val[N],to[N<<1],nxt[N<<1],head[N],cnt;
int sz[N],w[N],root,sum,vis[N],tot[1030],o[N];
ll ans;
void getroot(int u,int f)
{
sz[u]=1;w[u]=0;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]||v==f) continue;
getroot(v,u);
sz[u]+=sz[v];w[u]=max(w[u],sz[v]);
}
w[u]=max(w[u],sum-sz[u]);
if (w[u]<w[root]) root=u;
}
void getstatus(int u,int f,int sta)
{
o[++cnt]=sta;++tot[sta];
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]||v==f) continue;
getstatus(v,u,sta|val[v]);
}
}
ll calc(int u,int sta)
{
cnt=0;
memset(tot,0,sizeof(tot));
getstatus(u,0,sta);
for (int j=0;j<k;++j)
for (int i=all;i>=0;--i)
if (((1<<j)&i)==0) tot[i]+=tot[i|(1<<j)];
ll res=0;
for (int i=1;i<=cnt;++i) res+=tot[o[i]^all];
return res;
}
void solve(int u)
{
ans+=calc(u,val[u]);vis[u]=1;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]) continue;
ans-=calc(v,val[u]|val[v]);
sum=sz[v];root=0;
getroot(v,0);
solve(root);
}
}
int main()
{
while (scanf("%d%d",&n,&k)!=EOF)
{
memset(vis,0,sizeof(vis));
memset(head,0,sizeof(head));
cnt=0;all=(1<<k)-1;ans=0;
for (int i=1;i<=n;++i) val[i]=1<<(gi()-1);
for (int i=1,u,v;i<n;++i)
{
u=gi();v=gi();
to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;
to[++cnt]=u;nxt[cnt]=head[v];head[v]=cnt;
}
root=0;
sum=w[0]=n;
getroot(1,0);
solve(root);
printf("%lld
",ans);
}
}