简单的树形 dp
设 f[u][0]
表示 u
节点没有直接询问的情况下,查明 u
子树的方案数
f[u][1]
表示 u
节点询问过一次的情况, 查明 u
子树的方案数
考虑 f[u][0]
, 因为没有询问, 所以所有子节点都需要知晓
[f[u][0] = prod_{v} (f[v][0] + f[v][1])
]
考虑 f[u][1]
, 则有一个子节点没有询问,可以推出来, 对于该子节点所形成的子树来说,方案数相当于该子节点选了的方案数
[f[u][1] = sum_{v}f[u][0]*(f[v][0]+f[v][1])^{-1}*f[v][1]
]
这是一份错误的代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9 + 7;
const int N = 1e5 + 10;
vector<int> G[N];
using ll = long long;
ll f[N][2];
ll A[N][2], suf[N], pre[N];
int n;
void dfs(int u){
if(G[u].empty()){
f[u][0] = 0; f[u][1] = 1;
return;
}
int cnt = 0;
for(int v : G[u]){
dfs(v);
cnt++;
A[cnt][0] = f[v][0]; A[cnt][1] = f[v][1];
}
f[u][0] = 1; pre[0] = suf[cnt + 1] = 1;
for(int i = 1;i <= cnt;i++) f[u][0] = f[u][0] * (A[i][0] + A[i][1]) % mod;
for(int i = 1;i <= cnt;i++) pre[i] = pre[i-1] * (A[i][0] + A[i][1]) % mod;
for(int i = cnt;i >= 1;i--) suf[i] = suf[i+1] * (A[i][0] + A[i][1]) % mod;
f[u][1] = 0;
for(int i = 1;i <= cnt;i++) {
f[u][1] = (f[u][1] + (A[i][1] * pre[i - 1] % mod * suf[i + 1] % mod)) % mod;
}
}
int main(){
ios::sync_with_stdio(false);cin.tie(nullptr);
cin >> n;
for(int i = 2;i <= n;i++){
int fa;cin >> fa;
fa[G].push_back(i);
}
dfs(1);
cout << (f[1][0] + f[1][1]) % mod << '
';
}
错误的关键在于试图用全局变量把所有子树的结果保存下来, 还是对递归不够透彻
int cnt = 0;
for(int v : G[u]){
dfs(v);
cnt++;
A[cnt][0] = f[v][0]; A[cnt][1] = f[v][1];
}
还是需要开在栈里
#include <bits/stdc++.h>
using namespace std;
const long long mod = 1e9 + 7;
const int N = 1e5 + 10;
vector<int> G[N];
using ll = long long;
ll f[N][2];
ll suf[N], pre[N];
int n;
void dfs(int u){
if(G[u].empty()){
f[u][0] = 0; f[u][1] = 1;
return;
}
int cnt = 0;
vector<ll> A, B;
for(int v : G[u]){
dfs(v);
cnt++;
A.push_back(f[v][0]);
B.push_back(f[v][1]);
}
f[u][0] = 1; pre[0] = suf[cnt + 1] = 1;
for(int i = 1;i <= cnt;i++) f[u][0] = f[u][0] * (A[i-1] + B[i-1]) % mod;
for(int i = 1;i <= cnt;i++) pre[i] = pre[i-1] * (A[i-1] + B[i-1]) % mod;
for(int i = cnt;i >= 1;i--) suf[i] = suf[i+1] * (A[i-1] + B[i-1]) % mod;
f[u][1] = 0;
for(int i = 1;i <= cnt;i++) {
f[u][1] = (f[u][1] + ((B[i-1] * pre[i - 1] % mod * suf[i + 1] % mod))) % mod;
}
}
int main(){
ios::sync_with_stdio(false);cin.tie(nullptr);
cin >> n;
for(int i = 2;i <= n;i++){
int fa;cin >> fa;
fa[G].push_back(i);
}
dfs(1);
cout << (f[1][0] + f[1][1]) % mod << '
';
}