题解
童年的回忆!
想当初,这是我考的第一次省选,我当时初二,我什么都不会,然后看着这个东西,是不是能用我一个月前才会的求lca,光这个lca我就调了一个多小时= =,然后整场五个小时,我觉得其他题不可做,一直杠这题的20分,然后。。。day1爆零了,之后day2手玩提答好像骗了一点,总归是没爆零
那么这个东西是可爱的树dp啦
首先我们考虑m = n(因为大体的代码都需要这个思路吧
我们设(f[u][d])是u这个点不包括u,可以继续往上覆盖d个点最少需要的代价
(g[u][d])是u这个点包括u自己还需要往下覆盖d个点最少需要的代价(和f有点开闭区间的意味哦,但是方便转移,g是闭区间,f是开区间
转移枚举u的儿子
(f[u][d] = min(f[u][d],sum_{v in Son_u} f[v][d + 1] + sum_{t = otherson}g[t][d]))
(g[u][d] = sum_{v in Son_u}g[v][d - 1])
(g[u][ 0 ] = min(g[u][ 0 ],f[u][ 0 ]))
如果选择某个点的话
(f[u][d] = w[u] + sum_{v in Son_u} g[v][d])
我们还有一点,就是在计算完g数组和f数组后,要分别处理成g是前缀最小值,f是后缀最小值
可以这么想(g[u][d - 1])能达到的最小值,(g[u][d])也要能达到啊,显然我多覆盖一点,答案不会受影响
(f[u][d + 1])能达到的最小值,(f[u][d])也要能达到啊,显然我少覆盖一点,答案不会受影响
好的,我们刚才说的只是m = n的情况
如果不是,就证明有的点可以不覆盖,那么这么考虑,就是子树该覆盖的都需要被覆盖,自己不用被覆盖
那么就让(f[u][ 0 ])的初始值为
(f[u][ 0 ] = sum_{v in Son_u}g[v][ 0 ])
最后全部dp完成后答案就是(f[1][ 0 ])
代码
#include <bits/stdc++.h>
#define MAXN 500005
//#define ivorysi
using namespace std;
typedef long long int64;
struct node {
int to,next;
}edge[MAXN * 2];
int head[MAXN],sumedge;
int n,d;
int f[MAXN][21],g[MAXN][21],w[MAXN];
bool appear[MAXN];
void add(int u,int v) {
edge[++sumedge].to = v;
edge[sumedge].next = head[u];
head[u] = sumedge;
}
void addtwo(int u,int v) {
add(u,v);add(v,u);
}
void dp(int u,int fa) {
for(int i = 0 ; i <= d ; ++i) g[u][i] = 0x7fffffff;
for(int i = 0 ; i <= d ; ++i) f[u][i] = 0x7fffffff;
vector<int> s;
for(int i = head[u] ; i ; i = edge[i].next) {
int v = edge[i].to;
if(v != fa) {
dp(v,u);
s.push_back(v);
}
}
if(!appear[u]) {
f[u][0] = 0;
for(auto v : s) {
f[u][0] += g[v][0];
}
}
for(int i = 0 ; i <= d; ++i) {
if(i != d) {
int sum = 0;
for(auto v : s) sum += g[v][i];
for(auto v : s) {
f[u][i] = min(f[u][i],f[v][i + 1] + sum - g[v][i]);
}
}
if(i != 0) {
g[u][i] = 0;
for(auto v : s) g[u][i] += g[v][i - 1];
}
}
f[u][d] = w[u];
for(auto v : s) {
f[u][d] += g[v][d];
}
for(int i = d - 1 ; i >= 0 ; --i) {
f[u][i] = min(f[u][i + 1],f[u][i]);
}
g[u][0] = min(g[u][0],f[u][0]);
for(int i = 1 ; i <= d ; ++i) {
g[u][i] = min(g[u][i],g[u][i - 1]);
}
s.clear();
}
void Init() {
scanf("%d%d",&n,&d);
for(int i = 1 ; i <= n ; ++i) scanf("%d",&w[i]);
int m,u,v;
scanf("%d",&m);
for(int i = 1 ; i <= m ; ++i) {
scanf("%d",&u);
appear[u] = 1;
}
for(int i = 1 ; i < n ; ++i) {
scanf("%d%d",&u,&v);
addtwo(u,v);
}
}
int main() {
#ifdef ivorysi
freopen("f1.in","r",stdin);
#endif
Init();
dp(1,0);
printf("%d
",f[1][0]);
}