题目链接
AtCoder:https://agc008.contest.atcoder.jp/tasks/agc008_f
洛谷:https://www.luogu.org/problemnew/show/AT2268
Solution
首先假设所有点都是黑的。
设(f(i,d))表示(i)节点扩展(k)步的点集,那么答案就是本质不同的点集个数。
我们考虑一个很巧妙的计数方法:每种点集都在(d)最小时被算一次,那么二元组一定要满足这样的性质:
- 首先我们硬点全集不选,答案最后加一。
- 对于((x,d)),我们要求所有于(x)相邻的点(y)都不存在(f(x,d)=f(y,d-1))。
那么我们可以发现每个点都有一个选取上界,这个(d)满足以下性质:
- (din [0,dis_x-1]),其中(dis_x)表示离(x)最远点的距离。
- (din [0,dis2_v+1]),其中(v)为(x)的儿子,(dis2_v)表示(x)不经过(v)的(dis)最大值。
这个画个图就可以知道。
那么如果有一些点不是黑的,我们考虑给这些点定个下界,下界就是以(x)为根(x)的儿子的子树中含有黑点的子树的(dis_1)的最小值,这样就可以保证这种方案可以被一个黑点产生。
然后( m tree dp)实现就好了,复杂度(O(n))。
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('
');}
#define lf double
#define ll long long
#define pii pair<int,int >
#define vec vector<int >
#define pb push_back
#define mp make_pair
#define fr first
#define sc second
const int maxn = 5e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
char s[maxn];
int sz[maxn],d1[maxn],d2[maxn],d3[maxn],d4[maxn],n,head[maxn],tot,f[maxn];
struct edge{int to,nxt;}e[maxn<<1];
void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}
void dfs(int x,int fa) {
sz[x]=s[x]-'0',f[x]=fa;d3[x]=1e9;
for(int v,i=head[x];i;i=e[i].nxt)
if((v=e[i].to)!=fa) {
dfs(v,x),sz[x]+=sz[v];
d1[x]=max(d1[x],d1[v]+1);
if(sz[v]) d3[x]=min(d3[x],d1[e[i].to]+1);
}
}
void dfs2(int x,int fa) {
int fr=0,sc=0;if(fa) d4[x]=d2[x]-1;
for(int v,i=head[x];i;i=e[i].nxt) {
if((v=e[i].to)==fa) continue;
if(d1[v]+1>=fr) sc=fr,fr=d1[v]+1;
else if(d1[v]+1>sc) sc=d1[v]+1;
}
for(int v,i=head[x];i;i=e[i].nxt) {
if((v=e[i].to)==fa) continue;
if(d1[v]+1==fr) d2[v]=max(d2[x],sc)+1;
else d2[v]=max(d2[x],fr)+1;
dfs2(e[i].to,x);
}
}
int main() {
read(n);for(int i=1,x,y;i<n;i++) read(x),read(y),ins(x,y);
scanf("%s",s+1);dfs(1,0),dfs2(1,0);
ll ans=0;int mx,mn;
for(int x=1;x<=n;x++) {
mx=max(d1[x],d2[x])-1;
if(s[x]=='0') mn=min(d3[x],sz[1]==sz[x]?(int)1e9:d2[x]);else mn=0;
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to==f[x]) mx=min(mx,d1[x]+1);
else mx=min(mx,d4[e[i].to]+1);
if(mx>=mn) ans+=(ll)mx-mn+1;
}printf("%lld
",ans+1ll);
return 0;
}