首先尝试从初态走到 next 状态, 设根节点为 u,某个白色节点为 v, 将 u->v 染黑, 想象把这条路径缩成一点,它的子树们就是一个 next 状态:虽然是多个游戏的和。
一个自然的想法是 dfs, 然后暴力搞, 复杂度是 O(n^2)。
正解是对这个暴力的优化, 直接维护子游戏的集合,支持创建集合、全局异或、合并集合和查询mex, 重点在于实现这些操作的数据结构——0/1Trie。
最后, 这题的方方面面都与 CSP2019day1T3 的对于 O(n^3) 暴力的优化的思想特别相像, 比如暴力,比如正解,比如输出答案;但这个思想本质上也是个简单的东西:集合增添极少元素,某些东西可以快速维护。
若 Trie 要支持全局异或的操作, 就不可避免地要维护每个节点的 “深度”, 可以采取递归写法。
#include<bits/stdc++.h>
using namespace std;
const int N = 100003;
int n, v[N];
int ct, hd[N], nt[N*2+1], vr[N*2+1];
void ad(int u,int v) {
nt[++ct]=hd[u],hd[u]=ct; vr[ct]=v;
}
int rt[N];
struct Trie{
int ls[N*100], rs[N*100], cov[N*100], tag[N*100], tot;
void ins(int &u,int x,int d) {
u = ++tot;
if(d==-1) {
cov[u]=1; return;
}
if((x>>d)&1) ins(rs[u],x,d-1);
else ins(ls[u],x,d-1);
}
void put(int u,int x,int d) {
if(d==-1) return;
if((x>>d)&1) swap(ls[u],rs[u]);
tag[u] ^= x;
}
void ps_d(int u,int d) {
if(tag[u]) {
if(ls[u]) put(ls[u],tag[u],d-1);
if(rs[u]) put(rs[u],tag[u],d-1);
tag[u] = 0;
}
}
int meg(int u,int v,int d) {
if(u&&v) {
if(d==-1) {
cov[u]|=cov[v]; return u;
}
ps_d(u,d), ps_d(v,d);
ls[u] = meg(ls[u],ls[v],d-1);
rs[u] = meg(rs[u],rs[v],d-1);
cov[u] = (cov[ls[u]] && cov[rs[u]]);
return u;
} else return u|v;
}
int g_mex(int u,int d) {
if(d==-1 || !u) return 0;
if(cov[ls[u]]) return g_mex(rs[u],d-1) ^ (1<<d);
else return g_mex(ls[u],d-1);
}
} T;
int sg[N];
void dfs(int x,int fa) {
int ssg = 0;
for(int i=hd[x],y=vr[i]; i; i=nt[i],y=vr[i])
if(y^fa) {
dfs(y,x);
ssg ^= sg[y];
}
if(!v[x]) T.ins(rt[x],ssg,17);
for(int i=hd[x],y=vr[i]; i; i=nt[i],y=vr[i])
if(y^fa) {
T.put(rt[y],ssg^sg[y],17);
rt[x] = T.meg(rt[x],rt[y],17);
}
sg[x] = T.g_mex(rt[x],17);
}
int ans[N], m;
void g_ans(int x,int fa,int ssg) {
for(int i=hd[x],y=vr[i]; i; i=nt[i],y=vr[i])
if(y^fa) ssg ^= sg[y];
if(!v[x]&&ssg==0) ans[++m] = x;
for(int i=hd[x],y=vr[i]; i; i=nt[i],y=vr[i])
if(y^fa) g_ans(y,x,ssg^sg[y]);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;++i) scanf("%d",&v[i]);
for(int i=1,u,v; i<n; ++i) {
scanf("%d%d",&u,&v);
ad(u,v), ad(v,u);
}
dfs(1,0);
g_ans(1,0,0);
if(m==0) puts("-1");
else
{
sort(ans+1,ans+1+m);
for(int i=1;i<=m;++i) cout << ans[i] << '
';
}
return 0;
}