转载请注明出处:
http://www.cnblogs.com/hzoi-wangxh/p/7738629.html
Evensgn 剪树枝
时间限制:1s 空间限制:128MB
题目描述
繁华中学有一棵苹果树。苹果树有n 个节点(也就是苹果),n − 1 条边(也就
是树枝)。调皮的Evensgn 爬到苹果树上。他发现这棵苹果树上的苹果有两种:一
种是黑苹果,一种是红苹果。Evensgn想要剪掉 k 条树枝,将整棵树分成k + 1 个
部分。他想要保证每个部分里面有且仅有一个黑苹果。请问他一共有多少种剪树枝
的方案?
输入格式
第一行一个数字n,表示苹果树的节点(苹果)个数。
第二行一共n − 1 个数字p0, p1, p2, p3, ..., pn−2,pi表示第 i + 1 个节点和pi 节
点之间有一条边。注意,点的编号是0 到 n − 1。
第三行一共n 个数字 x0, x1, x2, x3, ..., xn−1。如果xi 是 1,表示i 号节点是黑
苹果;如果xi 是 0,表示i 号节点是红苹果。
输出格式
输出一个数字,表示总方案数。答案对109 + 7 取模。
样例输入1
3
0 0
0 1 1
6
样例输出1
2
样例输入2
6
0 1 1 0 4
1 1 0 0 1 0
样例输出2
1
样例输入3
10
0 1 2 1 4 4 4 0 8
0 0 0 1 0 1 1 0 0 1
样例输出3
27
数据范围
对于30% 的数据,1 ≤n ≤ 10。
对于60% 的数据,1 ≤n ≤ 100。
对于80% 的数据,1 ≤n ≤ 1000。
对于100% 的数据,1 ≤n ≤ 105。
对于所有数据点,都有0 ≤ pi ≤n − 1,xi = 0 或xi = 1。
特别地,60%中、80% 中、100%中各有一个点,树的形态是一条链。
题解:
其实是一个树规。
设f[i][j],f表示方案数,i表示以i为根节点的子树,j为0或1,0表示这棵子树的黑苹果数量等于砍的刀数,1代表砍的刀数比黑苹果数量少1.
为什么设这两种关系?我们可以想一下,整棵树有k个黑苹果,需要砍k-1刀,分成k个部分。如果把其中的一部分单独提出来,发现黑苹果数量比这一段砍的刀数少1,那么其他部分肯定是砍的刀数等于黑苹果数量。
接下来就是状态转移了。我们可以先跑一遍dfs,找出以i节点为根的子树中共有多少个黑苹果。如果为零,那这一段就不用搜了。因为这一棵树中反正也不能砍,对结果没有影响。
首先我们每向上走一层,把少砍一刀的情况加入砍全的情况,f[v][0]+=f[v][1](v为i的合法儿子)。接下来分两种情况,如果i节点为红,f[i][0]=∏f[v][0](v为i的合法儿子)。设sum=∏f[v][0],f[i][1]=Σ(sum/f[v][0]*f[v][1])。如果节点为黑,我们只考虑i的儿子,所以只存在f[i][1]=∏f[v][0].
最后输出f[1][1]。
注意时刻取模,sum/f[v][0]是用逆元。
附上代码
#include<iostream> #include<cstdlib> #include<cstring> #include<cstdio> #include<algorithm> #include<cmath> #include<vector> using namespace std; struct tree{ int u,v,next; }l[301000]; long long f[101000][5],mod=1000000007; int lian[101000],e=0,n,fa[101000],size[101000],a[101000]; void bian(int,int); void dfs(int); void dp(int); long long ksm(long long,long long); int main() { scanf("%d",&n); for(int i=2;i<=n;i++) { int x; scanf("%d",&x); x+=1; bian(x,i); bian(i,x); } for(int i=1;i<=n;i++) { scanf("%d",&a[i]); } dfs(1); dp(1); printf("%lld",f[1][1]); return 0; } void bian(int x,int y) { e++; l[e].u=x; l[e].v=y; l[e].next=lian[x]; lian[x]=e; } void dfs(int x) { if(a[x]!=0) size[x]+=1; for(int i=lian[x];i;i=l[i].next) { int v=l[i].v; if(v!=fa[x]) { fa[v]=x; dfs(v); size[x]+=size[v]; } } } void dp(int x) { int num=0; vector<int> ve; if(a[x]==0) { long long sum=1; for(int i=lian[x];i;i=l[i].next) { int v=l[i].v; if(v==fa[x]) continue; if(size[v]==0) continue; dp(v); num++; ve.push_back(v); f[v][0]+=f[v][1]; sum*=f[v][0]; sum%=mod; } f[x][0]=sum; for(int i=0;i<num;i++) { long long k=sum*ksm(f[ve[i]][0],mod-2)%mod; k*=f[ve[i]][1]; k%=mod; f[x][1]+=k; f[x][1]%=mod; } } else { long long sum=1; for(int i=lian[x];i;i=l[i].next) { int v=l[i].v; if(v==fa[x]) continue; if(size[v]==0) continue; dp(v); num++; ve.push_back(v); f[v][0]+=f[v][1]; sum*=f[v][0]; sum%=mod; } f[x][1]=sum; } } long long ksm(long long x,long long y) { long long ans=1,z=x; while(y) { if((y&1)==1) { ans*=z; ans%=mod; } y=y>>1; z*=z; z%=mod; } return ans; }