题目
题目链接:https://codeforces.com/contest/1481/problem/F
给定一棵 (n) 个节点的树,根为 (1) ,每个节点会分配到一个字符 a
或 b
。
要求整棵树中字符 a
的数量为 (x) ,字符 b
的数量为 (n-x) 。
定义节点 (v) 上的字符串:
- 若 (v) 是根节点,则 (v) 的上的字符串为根节点分配到的字符。
- 否则,(v) 上的字符串为父节点上的字符串的末尾加上 (v) 分配到的字符。
请为每个节点分配字符,在满足字符 a
,b
数量要求的前提下,使得所有节点上的字符串的种类最少。
(nleq 10^5)。
思路
神仙题。
答案的下界显然是树的高度。可以证明上界是树高 (+1)。
假设我们已经填完了 (1sim i-1) 层的所有节点。考虑第 (i) 层的节点。设这一层有 (k) 个非叶子节点,深度不小于 (i) 的有 (s) 个节点,那么显然 (kleq frac{s}{2})。
若此时有 (cnta) 个 a
未使用,(cntb) 个 b
未使用,那么必然有 (max(cnta,cntb)geq frac{s}{2}geq k)。所以这一层的非叶子节点都可以用同一种颜色。
再考虑这一层的叶子节点。假设这一层非叶子节点使用的是 a
,如果剩余 a
的数量不小于这一层叶子的数量,那么叶子也填上 a
,考虑下一层;如果剩余 a
的数量小于这一层叶子数量,那么直接把 a
用完,那么剩余所有点都用 b
填上。因为只有这一层的叶子是与非叶子不同的,所以不同的字符串数量恰好是树高 (+1)。
那么只需要判断答案是否能为树高即可。
如果答案恰好是树高,当且仅当存在某些层,节点的数量之和恰好为 (x)。因为每层节点数量的集合大小是 (O(sqrt{n})) 的,直接跑二进制分组完全背包判断是否有解就行了。
输出方案的话,如果答案为树高就背包记录路径;否则就按照上述方法构造。
时间复杂度 (O(nsqrt{n}))。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=100010,M=700;
int n,m,W,tot,maxd,ans[N],head[N],dep[N],cnt[N],w[N],v[N];
bool f[M][N];
vector<int> pos[N];
struct edge
{
int next,to;
}e[N];
void add(int from,int to)
{
e[++tot]=(edge){head[from],to};
head[from]=tot;
}
void dfs(int x,int fa)
{
dep[x]=dep[fa]+1; maxd=max(maxd,dep[x]);
pos[dep[x]].push_back(x);
for (int i=head[x];~i;i=e[i].next)
dfs(e[i].to,x);
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&W);
for (int i=2,x;i<=n;i++)
{
scanf("%d",&x);
add(x,i);
}
dfs(1,0);
for (int i=1;i<=maxd;i++)
cnt[pos[i].size()]++;
for (int i=1;i<=n;i++)
{
for (int j=0;j<=18;j++)
if (cnt[i]>=(1<<j))
m++,w[m]=(i<<j),v[m]=i,cnt[i]-=(1<<j);
if (cnt[i])
m++,w[m]=i*cnt[i],v[m]=i;
}
f[0][0]=1;
for (int j=1;j<=m;j++)
for (int i=0;i<=W;i++)
{
f[j][i]=f[j-1][i];
if (i>=w[j]) f[j][i]|=f[j-1][i-w[j]];
}
if (f[m][W])
{
cout<<maxd<<"
";
memset(cnt,0,sizeof(cnt));
for (int j=m,i=W;j;j--)
if (i>=w[j] && f[j-1][i-w[j]])
i-=w[j],cnt[v[j]]+=(w[j]/v[j]);
for (int i=1;i<=maxd;i++)
if (cnt[pos[i].size()])
{
cnt[pos[i].size()]--;
for (int j=0;j<(int)pos[i].size();j++)
ans[pos[i][j]]=1;
}
}
else
{
cout<<maxd+1<<"
";
int cnta=W,cntb=n-W,res=1;
for (int i=1;i<=maxd;i++)
{
if (cnta<cntb) swap(cnta,cntb),res^=1;
for (int j=0;j<(int)pos[i].size();j++)
{
int x=pos[i][j];
if (head[x]!=-1) ans[x]=res,cnta--;
}
for (int j=0;j<(int)pos[i].size();j++)
{
int x=pos[i][j];
if (!cnta) swap(cnta,cntb),res^=1;
if (head[x]==-1) ans[x]=res,cnta--;
}
}
}
for (int i=1;i<=n;i++)
putchar((ans[i]^1)+'a');
return 0;
}