这题中间转移一直没想明白,一直卡壳。
状态dp[i][j] - i向下最深距离为j的方案数。
状态设置对了,但是发现,子树一直转移维护的时候,枚举i,j一直当成了i + j + 1的深度了。
但是这个其实是距离,我们需要维护的应该是max(i,j + 1)的深度。这样才能转移。
但是这样还是不好考虑,因为这个点u如果是关键点的话,就会非常特殊。
但是我们可以发现,u的子树组成的方案一定会经过u,那么我们可以先不管u的方案,最后再所有方案都 * 2即可。
这里转移有两种形式:
1 :子树一直转移。
2 :子树和u组成的贡献,这里也是很关键的一步。
首先这里我们可以枚举i,j去转移,但是这样显然是n ^ 2会T,我们可以考虑维护一下每个点能往下走的最远距离,降低一下枚举的上限。
这样就不会T了,但是观察第二维就可以发现,我们可以做一个前缀和来实现O(n)的转移。
然后就是很关键的地方:由于这题数据太弱了,很多假代码都过了,包括我一开始写的也是。
看下面这组数据。
4 4 1
1 2
1 3
2 4
1 2 3 4
输出:7
看了网上大部分的题解代码都是5,和我一开始一样。
假代码处理答案:
LL ans = 0; for(int i = 0;i <= k;++i) ans = ADD(ans,dp[1][i]);
这是因为没有算上子树中满足条件的答案。
因为我们这样算,必定就要经过1.但是不经过1满足的方案数都没算入。
考虑维护一个超过k的方案数。
如果子树中这个距离的方案数 != 0,那么就直接转移上来。说明他是以某个子树为中间点的合法方案数。
// Author: levil #include<bits/stdc++.h> using namespace std; typedef long long LL; typedef unsigned long long ULL; typedef pair<int,int> pii; typedef tuple<int,int,int> tu; const int N = 5e3 + 5; const int M = 1e3 + 5; const double eps = 1e-10; const LL Mod = 1e9 + 7; #define pi acos(-1) #define INF 1e9 #define dbg(ax) cout << "now this num is " << ax << endl; inline int read() { int f = 1;int x = 0;char c = getchar(); while(c < '0' || c > '9') {if(c == '-') f = -1;c = getchar();} while(c >= '0' && c <= '9'){x = (x<<1)+(x<<3)+(c^48);c = getchar();} return x*f; } inline long long ADD(long long x,long long y) {return (x + y) % Mod;} inline long long DEC(long long x,long long y) {return (x - y + Mod) % Mod;} inline long long MUL(long long x,long long y) {return x * y % Mod;} int n,m,k; vector<int> G[N]; bool vis[N]; LL dp[N][N];//dp[i][j] - 距离i最大距离为j的方案数 LL f[N]; int mxd[N]; void dfs(int u,int fa) { for(auto v : G[u]) { if(v == fa) continue; dfs(v,u); memset(f,0,sizeof(f)); for(int i = 0;i <= min(k,mxd[u]);++i) { for(int j = 0;j <= min(k,mxd[v]);++j) { if(i + j + 1 <= k) f[max(i,j + 1)] = ADD(f[max(i,j + 1)],MUL(dp[u][i],dp[v][j])); } } mxd[u] = max(mxd[u],mxd[v] + 1); for(int i = 1;i <= mxd[u];++i) dp[u][i] = ADD(dp[u][i],dp[v][i - 1]); for(int i = 0;i <= mxd[u];++i) dp[u][i] = ADD(dp[u][i],f[i]); } if(vis[u] == 1) { for(int i = 0;i <= k;++i) dp[u][i] = MUL(2,dp[u][i]); dp[u][0]++; } //for(int i = 0;i <= min(mxd[u],k);++i) printf("dp[%d][%d] is %lld ",u,i,dp[u][i]); } void solve() { n = read(),m = read(),k = read(); for(int i = 1;i < n;++i) { int u,v;u = read(),v = read(); G[u].push_back(v); G[v].push_back(u); } for(int i = 1;i <= m;++i) { int x = read(); vis[x] = 1; } dfs(1,0); LL ans = 0; for(int i = 0;i <= mxd[1];++i) ans = ADD(ans,dp[1][i]); printf("%lld ",ans); } int main() { solve(); system("pause"); return 0; } /* 4 4 2 1 2 1 3 2 4 1 2 3 4 4 4 1 1 2 1 3 2 4 1 2 3 4 */