一道很玄妙的题= =
我们考虑先考虑DP 那么有$f[x]=min(c+sum f[y])$ $f[x]$表示覆盖x的子树和x->fa[x]的所有边最小代价 我们枚举一条边c覆盖的x->fa[x]并把它作为主链 f[y]就是除了主链以外的所有点的dp
接着考虑这个玩意怎么维护 我们可以在dp过程中直接把$sum f[y]$放入$c$中 就变成了下面的这些操作
1.将终点在x的链删除。
2.记$sum=sum f[y] y=son[x]$,son[i]子树内所有的链$c+=sum-f[son[i]]$,特别地,起点在i的链$c+=sum$。
3.取出f[x]是子树x中所有的链c的最小值。
显然这个可以数据结构维护掉
接下来我们考虑更为简洁的做法。
我们还是考虑每条向父亲的边都需要被覆盖。所以我们在覆盖x->fa[x]的时候我们是把所有的x的子树的链都合并起来然后选出一条覆盖这个边的。
直接用堆维护,这样的贪心显然是不对的。但是我们考虑用整体标记覆盖的方法。也就是取出堆顶v然后对堆中所有元素打上-v的标记 这样的话就可以选出别的链来替换掉当前的选择。这个方法非常有趣,一会写的另一道题也是用的标记覆盖的方法来维护。
然后我们在每条链的尽头需要把它删掉,实际上也不需要彻底删掉,我们只需要让它不能成为答案即可。这个在取堆顶的时候判断一下就可以了。
这个题很坑的地方就是在pop的时候需要把当前的标记下传掉,然而很多人都没有写这个地方,CF数据也较弱没有卡掉这个问题。在校内OJ上WA到自闭一度以为算法错了的我流下了悲伤的泪水TAT。
//Love and Freedom. #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> #define ll long long #define inf 20021225 #define ls(x) t[x].son[0] #define rs(x) t[x].son[1] #define N 300010 using namespace std; int read() { int s=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();} while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar(); return s*f; } struct node{int fa,son[2],dep; ll val,tag;}t[N]; struct edge{int to,lt;}e[N<<1]; int in[N],cnt; ll ans; void add(int x,int y) { e[++cnt].to=y; e[cnt].lt=in[x]; in[x]=cnt; e[++cnt].to=x; e[cnt].lt=in[y]; in[y]=cnt; } void put(int x,ll v){if(!x) return; t[x].tag+=v,t[x].val+=v;} void pushdown(int x) { if(!t[x].tag) return; put(ls(x),t[x].tag); put(rs(x),t[x].tag); t[x].tag=0; } int merge(int x,int y) { if(!x||!y) return x|y; if(t[y].val<t[x].val) swap(x,y); pushdown(x); t[x].son[1]=merge(t[x].son[1],y); t[ls(x)].fa=t[rs(x)].fa=x; t[x].fa=x; if(t[rs(x)].dep>t[ls(x)].dep) swap(ls(x),rs(x)); t[x].dep=t[rs(x)].dep+1; return x; } int rtn[N],top[N]; bool vis[N]; bool GG; void dfs(int x,int f) { for(int i=in[x];i;i=e[i].lt) { int y=e[i].to; if(f==y) continue; dfs(y,x); if(GG) return; rtn[x]=merge(rtn[x],rtn[y]); } vis[x]=1; if(x==1) return; while(vis[top[rtn[x]]]) pushdown(rtn[x]),rtn[x]=merge(ls(rtn[x]),rs(rtn[x])); if(!rtn[x]){GG=1; return;} ans+=t[rtn[x]].val; put(rtn[x],-t[rtn[x]].val); } int main() { int n=read(),m=read(); for(int i=1;i<n;i++){int x=read(),y=read(); add(x,y);} for(int i=1;i<=m;i++) { int x=read(); top[i]=read(); t[i].val=read(); rtn[x]=merge(rtn[x],i); } dfs(1,0); printf("%lld ",GG?-1:ans); return 0; }