大致题意: 有一棵树,树上编号为(i)的节点上有(F_i)个铁球,逃亡者有(V)个磁铁,当他在某个节点放下磁铁时,与这个节点相邻的所有节点上的铁球都会被吸引到这个节点。然后一个追逐者会顺着同样的路去追逐逃亡者。问追逐者遇到的铁球数减去逃亡者遇到的铁球数的最大值。
一个暴力(DP)
我们先来考虑一个暴力的树形(DP)。
不难发现,经过一个节点所能得到的收益应该是它的所有子节点的权值之和,即:$$f_{x,i}=max(f_{fa_x,i},f_{fa_x,i-1}+Size_x)$$
其中(f_{x,i})表示从以某一节点为根节点出发(根节点可以拿来枚举,毕竟这是暴力的做法)到达编号为(x)的节点,放下(i)个磁铁所能到达的最大收益。而(Size_x)则表示(x)的所有子节点的铁球数量之和。
代码如下:
#include<bits/stdc++.h>
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)<(y)?(x):(y))
#define LL long long
#define ull unsigned long long
#define swap(x,y) (x^=y,y^=x,x^=y)
#define tc() (A==B&&(B=(A=ff)+fread(ff,1,100000,stdin),A==B)?EOF:*A++)
#define pc(ch) (pp_<100000?pp[pp_++]=(ch):(fwrite(pp,1,100000,stdout),pp[(pp_=0)++]=(ch)))
#define N 100000
#define M 100
LL pp_=0;char ff[100000],*A=ff,*B=ff,pp[100000];
using namespace std;
LL n,m,ee=0,res,a[N+5],lnk[N+5],fa[N+5],Size[N+5],f[N+5][M+5];
struct edge
{
LL to,nxt;
}e[(N<<1)+5];
inline void read(LL &x)
{
x=0;LL f=1;static char ch;
while(!isdigit(ch=tc())) f=ch^'-'?1:-1;
while(x=(x<<3)+(x<<1)+ch-48,isdigit(ch=tc()));
x*=f;
}
inline void write(LL x)
{
if(x<0) pc('-'),x=-x;
if(x>9) write(x/10);
pc(x%10+'0');
}
inline void add(LL x,LL y)//新加上一条边
{
e[++ee].to=y,e[ee].nxt=lnk[x],lnk[x]=ee,e[++ee].to=x,e[ee].nxt=lnk[y],lnk[y]=ee;
}
inline void GetRt(LL x)//以一个新的节点为根,因此要重新遍历树上的每个节点,并更新每个节点的信息
{
Size[x]=0;//现将当前节点的子节点铁球数量和清零
for(register LL i=lnk[x];i;i=e[i].nxt)//枚举当前节点的每一个子节点
if(e[i].to^fa[x]) fa[e[i].to]=x,GetRt(e[i].to),Size[x]+=a[e[i].to];//更新子节点的信息,并计算出当前节点子节点的铁球数量和
}
inline void DP(LL x)//树形DP,计算出对于每个节点的答案
{
register LL i;
for(res=max(res,f[x][0]=f[fa[x]][0]),i=1;i<=m;++i) f[x][i]=max(f[fa[x]][i],f[fa[x]][i-1]+Size[x]),res=max(res,f[x][i]);//DP转移,并用res记录f数组最大值
for(i=lnk[x];i;i=e[i].nxt) if(e[i].to^fa[x]) DP(e[i].to);//对当前节点的每一个子节点进行DP
}
inline void Clear(int x)//换一个新的根节点,清空数组
{
register LL i,j;
for(fa[x]=res=i=0;i<=n;++i)
for(j=0;j<=m;++j) f[i][j]=0;
}
inline void GetAns(LL x)//求出以x为根节点的答案
{
register LL i;
for(Clear(x),GetRt(x),f[x][1]=Size[x],i=lnk[x];i;i=e[i].nxt) DP(e[i].to);//访问根节点的每一个子节点,对其进行树形DP
}
int main()
{
register LL i,ans=0;
for(read(n),read(m),i=1;i<=n;++i) read(a[i]);
for(i=1;i<n;++i)
{
static LL x,y;
read(x),read(y),add(x,y);
}
for(i=1;i<=n;++i) GetAns(i),ans=max(ans,res);//枚举每一个节点作为根节点,并求出对应的答案
return write(ans),fwrite(pp,1,pp_,stdout),0;
}
考虑优化
不难发现,上面这个(DP)的时间复杂度是(O(n^2m))的(枚举根节点是(O(n))的,(DP)是(O(nm))的),显然会(TLE),因此我们要考虑优化。
应该可以发现,(DP)的时间复杂度是难以优化的,因此我们就要想想看能不能不枚举根节点直接(DP)。
首先,我们以1号节点为根。
然后,我们可以用(Up_{i,j})来表示(i)的子树中的某个节点走到(i),放下(j)个磁铁所能得到的最大收益,并用(Down_{i,j})来表示(i)走到(i)的子树中的某个节点,放下(j)个磁铁所能得到的最大收益。
然后,就可以得出转移方程:
[Up_{x,i}=max(Up_{x,i},max(Up_{son,i},Up_{son,i-1}+sum_x-a_{son}))
]
[Down_{x,i}=max(Down_{x,i},max(Down_{son,i},Down_{son,i-1}+sum_x-a_{fa}))
]
代码如下:
#include<bits/stdc++.h>
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)<(y)?(x):(y))
#define LL unsigned long long
#define N 100000
#define M 100
#define add(x,y) (e[++ee].to=y,e[ee].nxt=lnk[x],lnk[x]=ee,sum[x]+=a[y])
char ff[10000000],*A=ff;
using namespace std;
LL n,m,ans,ee=0,top=0,a[N+5],lnk[N+5],sum[N+5],Up[N+5][M+5],Down[N+5][M+5],Stack[N+5];
struct edge
{
LL to,nxt;
}e[(N<<1)+5];
inline void read(LL &x)
{
x=0;static char ch;
while(!isdigit(ch=*A++));
while(x=(x<<3)+(x<<1)+ch-48,isdigit(ch=*A++));
}
inline void write(LL x)
{
if(x>9) write(x/10);
putchar(x%10+48);
}
inline void DP(LL x,LL son,LL fa)//对节点x进行树形DP,其中son为该节点的一个子节点,fa为该节点的父亲节点
{
register LL i;
for(i=1;i<=m;++i) ans=max(ans,Up[x][i]+Down[son][m-i]);//先记录答案,然后再更新,不然会出现重复计算
for(i=1;i<=m;++i) Up[x][i]=max(Up[x][i],max(Up[son][i],Up[son][i-1]+sum[x]-a[son])),Down[x][i]=max(Down[x][i],max(Down[son][i],Down[son][i-1]+sum[x]-a[fa]));//DP转移
}
inline void dfs(LL x,LL lst)//DFS遍历每一个节点
{
register LL i;
for(i=1;i<=m;++i) Down[x][i]=(Up[x][i]=sum[x])-a[lst];//初始化
for(i=lnk[x];i;i=e[i].nxt) if(e[i].to^lst) dfs(e[i].to,x),DP(x,e[i].to,lst);//对每一个子节点进行DP
for(i=1;i<=m;++i) Down[x][i]=(Up[x][i]=sum[x])-a[lst];//一次DP可能有问题,因此要倒着再DP一遍,所以要重新初始化
for(i=lnk[x];i;i=e[i].nxt) if(e[i].to^lst) Stack[++top]=e[i].to;//将每个节点加入一个栈中
while(top) dp(x,Stack[top--],lst);//倒着再DP一遍
ans=max(ans,max(Up[x][m],Down[x][m]));//更新ans
}
int main()
{
register LL i,x,y;
for(fread(ff,1,10000000,stdin),read(n),read(m),i=1;i<=n;++i) read(a[i]);
for(i=1;i<n;++i) read(x),read(y),add(x,y),add(y,x);
return dfs(1,0),write(ans),0;//这样只要以1号节点为根即可,然后输出ans
}