题目:https://www.lydsy.com/JudgeOnline/problem.php?id=4712
设 f[x] = min(∑f[u] , a[x]),ls = ∑f[lson]
矩阵是这样的:
ls, a[x]
0, 0
所以假如后面乘一个
f[u], 0
0, 0
就得到了 f[x];
注意,因为定义结构体时把数组都赋成 inf 了,所以后面要用 0 时必须专门赋值成 0;
查询时不是从重链顶开始的,所以不用 get,直接 query(x,ed[top[x]]);
最后 f[x] 和 a[x] 再取一下 min 即可(相当于乘全是0的矩阵)。
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define mid ((l+r)>>1) #define ls (x<<1) #define rs (x<<1|1) using namespace std; typedef long long ll; int const xn=2e5+5; int n,hd[xn],ct,to[xn<<1],nxt[xn<<1],fa[xn],dfn[xn],siz[xn],son[xn]; int id[xn],tim,top[xn],ed[xn]; ll a[xn],f[xn],inf=1e10; ll mnn(ll a,ll b){return a<b?a:b;} struct N{ ll a[2][2]; N(){a[0][0]=a[1][0]=a[0][1]=a[1][1]=inf;} N operator * (const N &y) const { N ret; for(int i=0;i<2;i++) for(int k=0;k<2;k++) for(int j=0;j<2;j++) ret.a[i][j]=mnn(ret.a[i][j],a[i][k]+y.a[k][j]); return ret; } }t[xn<<2],s[xn]; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar(); return f?ret:-ret; } void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;} void dfs(int x,int ff) { fa[x]=ff; siz[x]=1; for(int i=hd[x],u;i;i=nxt[i]) { if((u=to[i])==ff)continue; dfs(u,x); siz[x]+=siz[u]; if(siz[u]>siz[son[x]])son[x]=u; } } void dfs2(int x) { dfn[x]=++tim; id[tim]=x; f[x]=a[x]; ll tmp=0; s[dfn[x]].a[0][0]=inf; s[dfn[x]].a[0][1]=a[x];// s[dfn[x]].a[1][0]=s[dfn[x]].a[1][1]=0;//!!! if(son[x])top[son[x]]=top[x],dfs2(son[x]); else {ed[top[x]]=dfn[x]; return;}//!son: a[0][0]=inf for(int i=hd[x],u;i;i=nxt[i]) if((u=to[i])!=fa[x]&&u!=son[x]) { top[u]=u; dfs2(u); tmp+=f[u]; } s[dfn[x]].a[0][0]=tmp; f[x]=mnn(a[x],tmp+f[son[x]]); } void build(int x,int l,int r) { if(l==r){t[x]=s[l]; return;} build(ls,l,mid); build(rs,mid+1,r); t[x]=t[ls]*t[rs]; } void upt(int x,int l,int r,int pos) { if(l==r){t[x]=s[l]; return;} if(pos<=mid)upt(ls,l,mid,pos); else upt(rs,mid+1,r,pos); t[x]=t[ls]*t[rs]; } N query(int x,int l,int r,int L,int R) { if(l>=L&&r<=R)return t[x]; if(mid>=R)return query(ls,l,mid,L,R); if(mid<L)return query(rs,mid+1,r,L,R); return query(ls,l,mid,L,R)*query(rs,mid+1,r,L,R); } N get(int x){return query(1,1,n,dfn[x],ed[x]);} void chg(int x,int ss) { s[dfn[x]].a[0][1]+=ss;//dfn[x] N pr,nw; while(x) { pr=get(top[x]); upt(1,1,n,dfn[x]); nw=get(top[x]); x=fa[top[x]]; s[dfn[x]].a[0][0]+=mnn(nw.a[0][0],nw.a[0][1])-mnn(pr.a[0][0],pr.a[0][1]);// } } char ch[10]; int main() { n=rd(); for(int i=1;i<=n;i++)a[i]=rd(); for(int i=1,x,y;i<n;i++)x=rd(),y=rd(),add(x,y),add(y,x); dfs(1,0); top[1]=1; dfs2(1); build(1,1,n); int m=rd(); for(int i=1,x,v;i<=m;i++) { scanf("%s",ch); x=rd(); if(ch[0]=='C')v=rd(),chg(x,v); else { N tmp=query(1,1,n,dfn[x],ed[top[x]]);// printf("%lld ",mnn(tmp.a[0][0],tmp.a[0][1])); } } return 0; }