最近调题都好艰辛啊……(请忽略弱者的一句感叹)
虚树,为关键点及其(lca)构成的树,在这棵树上跑(dp)有时可以降低复杂度。
若我有(m)个关键点,以下给出虚树建树模板:
st[++top]=1;beg[1]=0;
for(int i=1;i<=m;i++){
if(a[i]==1)continue;
int anc=lca(a[i],st[top]);
if(anc!=st[top]){
while(top>1&&dep[st[top-1]]>dep[anc])
add(st[top-1],st[top]),top--;
if(st[top-1]]!=anc)beg[anc]=0,add(anc,st[top]),st[top]=anc;
else add(anc,st[top--]);
}
beg[a[i]]=0;st[++top]=a[i];
}
for(int i=1;i<top;i++)
add(st[top],st[top+1]);
然后在虚树上(dp)。
若此点为关键点,则任意子关键点都有1的贡献(必须切断中间的一个点)。
否则,若有多于一个关键子节点,则产生1的贡献(必须切断此点)。
代码如下,仅供参考:
#include<bits/stdc++.h>
using namespace std;
inline int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
const int maxn=1e5+10;
int n,m,cnt,ans,a[maxn],val[maxn];
int beg[maxn],nex[maxn<<1],to[maxn<<1],e;
inline void add(int x,int y){
e++;nex[e]=beg[x];
beg[x]=e;to[e]=y;
}
int dep[maxn],dfn[maxn],f[maxn][20];
inline void dfs(int x,int fa){
dep[x]=dep[fa]+1;
f[x][0]=fa;dfn[x]=++cnt;
for(int i=1;i<=19;i++)
f[x][i]=f[f[x][i-1]][i-1];
for(int i=beg[x];i;i=nex[i])
if(to[i]!=fa)dfs(to[i],x);
}
int st[maxn],top;
inline int cmp(int x,int y){return dfn[x]<dfn[y];}
inline int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=19;i>=0;i--)
if(dep[f[x][i]]>=dep[y])x=f[x][i];
if(x==y)return x;
for(int i=19;i>=0;i--)
if(f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
return f[x][0];
}
int siz[maxn],dp[maxn];
inline void solve(int x){
siz[x]=dp[x]=0;
for(int i=beg[x];i;i=nex[i]){
int t=to[i];
solve(t);
siz[x]+=siz[t];
dp[x]+=dp[t];
}
if(val[x])dp[x]+=siz[x],siz[x]=1;
else if(siz[x]>1)siz[x]=0,dp[x]++;
}
int main(){
n=read();
int x,y;
for(int i=1;i<n;i++){
x=read(),y=read();
add(x,y),add(y,x);
}
dfs(1,0);
memset(beg,0,sizeof(beg));e=0;
m=read();
for(int i=1;i<=m;i++){
cnt=read();top=e=0;
for(int j=1;j<=cnt;j++)
a[j]=read(),val[a[j]]=1;
int flag=0;
for(int j=1;j<=cnt;j++)
if(val[f[a[j]][0]])flag=1;
if(flag){
for(int j=1;j<=cnt;j++)
val[a[j]]=0;
puts("-1");
continue;
}
sort(a+1,a+1+cnt,cmp);
st[++top]=1;beg[1]=0;
for(int j=1;j<=cnt;j++){
if(a[j]==1)continue;
int anc=lca(a[j],st[top]);
if(anc!=st[top]){
while(top>1&&dep[st[top-1]]>dep[anc])
add(st[top-1],st[top]),top--;
if(st[top-1]!=anc)beg[anc]=0,add(anc,st[top]),st[top]=anc;
else add(anc,st[top--]);
}
beg[a[j]]=0;st[++top]=a[j];
}
for(int j=1;j<top;j++)
add(st[j],st[j+1]);
solve(1);
printf("%d
",dp[1]);
for(int j=1;j<=cnt;j++)
val[a[j]]=0;
}
return 0;
}
深深地感到自己的弱小。