问题:有n个人,最多选k个,如果选了某个人就必须选他指定的另一个人,问最多能选多少个人。
将每个人所指定的人向他连一条单向边,则每一个点都有唯一的前驱,形成的图是个基环树森林,在同一个强连通分量里的点要么全选,要么全不选。
首先用Tarjan算法将每个强连通分量(基环树上的环)缩成一个点,这样每棵基环树就变成了普通的树了。
定义每颗树上没有入度的点为树根,建立一个虚根与每棵树的根连一条边,将森林转化成树,对根节点求一遍树形背包即可。
树形依赖背包是树形背包的一个特例,即树形背包在根节点上的dp值。
可用siz数组或者bitset优化。
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1000+10; 5 int hd[N],op[N],ne,n,k,dp[N][N],dg[N],siz[N],mx[N],dfn[N],low[N],scc[N],sta[N],tot,nscc,tp; 6 struct E {int v,nxt;} e[N<<1]; 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++,dg[v]++;} 8 void Tarjan(int u) { 9 low[u]=dfn[u]=++tot; 10 sta[++tp]=u; 11 int v=op[u]; 12 if(!dfn[v])Tarjan(v),low[u]=min(low[u],low[v]); 13 else if(!scc[v])low[u]=min(low[u],dfn[v]); 14 if(low[u]==dfn[u])for(nscc++; !scc[u]; scc[sta[tp--]]=nscc); 15 } 16 void getscc() { 17 memset(scc,0,sizeof scc); 18 memset(dfn,0,sizeof dfn); 19 nscc=tot=0,tp=-1; 20 for(int i=1; i<=n; ++i)if(!dfn[i])Tarjan(i); 21 memset(siz,0,sizeof siz); 22 memset(dg,0,sizeof dg); 23 for(int i=1; i<=n; ++i)siz[scc[i]]++; 24 for(int u=1; u<=n; ++u) { 25 int v=op[u]; 26 if(scc[v]!=scc[u])addedge(scc[v],scc[u]); 27 } 28 for(int i=1; i<=nscc; ++i)if(!dg[i])addedge(0,i); 29 } 30 void dfs(int u) { 31 memset(dp[u],0,sizeof dp[u]); 32 dp[u][siz[u]]=1; 33 for(int i=hd[u]; ~i; i=e[i].nxt) { 34 int v=e[i].v; 35 dfs(v); 36 for(int j=siz[u]; j>=0; --j)if(dp[u][j]) 37 for(int k=0; k<=siz[v]; ++k)if(dp[v][k]) 38 dp[u][j+k]=1; 39 siz[u]+=siz[v]; 40 } 41 } 42 43 int main() { 44 memset(hd,-1,sizeof hd),ne=0; 45 scanf("%d%d",&n,&k); 46 for(int i=1; i<=n; ++i)scanf("%d",&op[i]); 47 getscc(); 48 dfs(0); 49 for(int i=k; i>=0; --i)if(dp[0][i]) {printf("%d ",i); break;} 50 return 0; 51 }
bitset优化版:
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int N=1000+10; 5 int hd[N],op[N],ne,n,k,dg[N],siz[N],dfn[N],low[N],scc[N],sta[N],tot,nscc,tp; 6 bitset<N> dp[N]; 7 struct E {int v,nxt;} e[N]; 8 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++,dg[v]++;} 9 void Tarjan(int u) { 10 low[u]=dfn[u]=++tot; 11 sta[++tp]=u; 12 int v=op[u]; 13 if(!dfn[v])Tarjan(v),low[u]=min(low[u],low[v]); 14 else if(!scc[v])low[u]=min(low[u],dfn[v]); 15 if(low[u]==dfn[u])for(nscc++; !scc[u]; scc[sta[tp--]]=nscc); 16 } 17 void getscc() { 18 memset(scc,0,sizeof scc); 19 memset(dfn,0,sizeof dfn); 20 nscc=tot=0,tp=-1; 21 for(int i=1; i<=n; ++i)if(!dfn[i])Tarjan(i); 22 memset(siz,0,sizeof siz); 23 memset(dg,0,sizeof dg); 24 for(int i=1; i<=n; ++i)siz[scc[i]]++; 25 for(int u=1; u<=n; ++u) { 26 int v=op[u]; 27 if(scc[v]!=scc[u])addedge(scc[v],scc[u]); 28 } 29 for(int i=1; i<=nscc; ++i)if(!dg[i])addedge(0,i); 30 } 31 void dfs(int u) { 32 dp[u].reset(); 33 dp[u].set(siz[u]); 34 for(int i=hd[u]; ~i; i=e[i].nxt) { 35 int v=e[i].v; 36 dfs(v); 37 bitset<N> t=dp[u]; 38 for(int j=0; j<N; ++j)if(dp[v].test(j))dp[u]|=t<<j; 39 } 40 } 41 42 int main() { 43 memset(hd,-1,sizeof hd),ne=0; 44 scanf("%d%d",&n,&k); 45 for(int i=1; i<=n; ++i)scanf("%d",&op[i]); 46 getscc(); 47 dfs(0); 48 for(int i=k; i>=0; --i)if(dp[0].test(i)) {printf("%d ",i); break;} 49 return 0; 50 }