看起来很模板的一个题啊
qwq
但是我还是wei
题目要求的是一个把根节点和所有叶子断开连接的最小花费。
还是想一个比较(naive)的做法
我们令(dp1[i])表示,在(i)的子树内,把叶子全都隔断的最小代价,那么
[dp1[i]=max(sum dp1[p],val[i])
]
但是这样暴力并不能通过这个题。
考虑怎么来优化更新的过程呢。
由于是树上问题,根据套路,我们对原树进行树链剖分。
令(dp[i])表示除去重儿子的所有(dp1[p])的和。
那么我们重新定义矩阵乘法(c[i][j]=max(c[i][j],a[i][k]+b[k][j]))之后,就可以通过矩阵来完成转移了
我们令(g)表示包含重儿子的(ans),然后令(f)表示上述的(dp[i])
令(p)表示重儿子。
那么不难发现
g[p] 0
g[p] 0
和
f[i] +inf
val[i] 0
做矩阵乘法之后,就能得到
g[i] 0
g[i] 0
那我们可以直接用线段树来维护矩阵乘法来进行快速修改和求值了。
但是有一个需要注意的地方就是,对于重链链尾的所有元素的对应转移矩阵要特殊处理,因为他们的(g)是等于(f)的。
那么修改的时候,先进行单点修改(要特判链尾)
然后依次修改每一条重链的(fa)的转移矩阵即可。
qwq一开始有很多地方都没有想明白,就很wei
细节就直接看代码吧
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<cmath>
#include<map>
#include<set>
#define pb push_back
#define mk make_pair
#define ll long long
#define lson ch[x][0]
#define rson ch[x][1]
#define int long long
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)) {if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)) {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 2e5+1e2;
const int maxm = 2*maxn;
const int inf = 1e18;
struct Ju{
int x,y;
int a[3][3];
Ju operator * (Ju b)
{
Ju ans;
ans.x=ans.y=2;
memset(ans.a,0x3f,sizeof(ans.a));
for (int i=1;i<=2;i++)
for(int j=1;j<=2;j++)
for (int k=1;k<=y;k++)
{
ans.a[i][j]=min(ans.a[i][j],a[i][k]+b.a[k][j]);
}
return ans;
}
};
int point[maxn],nxt[maxm],to[maxm],val[maxn];
int cnt,n,m;
Ju pre[maxn];
Ju f[4*maxn];
int top[maxn],newnum[maxn],tail[maxn];
int fa[maxn],son[maxn],size[maxn];
int q;
int back[maxn];
int dp1[maxn],dp[maxn];
int sum[maxn];
//dp[x] doesn't include son[x]
//这个dp数组实质上就是一个sum的形式。
void addedge(int x,int y)
{
nxt[++cnt]=point[x];
to[cnt]=y;
point[x]=cnt;
}
void up(int root)
{
f[root]=f[2*root+1]*f[2*root];
}
void build(int root,int l,int r)
{
if (l==r)
{
int ymh = back[l];
f[root].x=f[root].y=2;
if (tail[top[ymh]]==ymh)
{
f[root].a[1][1]=f[root].a[2][1]=dp1[ymh];
}
else
{
f[root].a[1][1]=dp[ymh];
f[root].a[1][2]=inf;
f[root].a[2][1]=val[ymh];
}
return;
}
int mid = l+r >> 1;
build(2*root,l,mid);
build(2*root+1,mid+1,r);
up(root);
}
void update(int root,int l,int r,int x,Ju p)
{
if(l==r)
{
f[root]=p;
return;
}
int mid = l+r >> 1;
if (x<=mid) update(2*root,l,mid,x,p);
else update(2*root+1,mid+1,r,x,p);
up(root);
}
Ju query(int root,int l,int r,int x,int y)
{
if (x<=l && r<=y)
{
return f[root];
}
int mid = l+r >> 1;
if (x>mid) return query(2*root+1,mid+1,r,x,y);
if (y<=mid) return query(2*root,l,mid,x,y);
return query(2*root+1,mid+1,r,x,y)*query(2*root,l,mid,x,y);
}
void dfs1(int x,int faa)
{
size[x]=1;
int maxson=-1;
for (int i=point[x];i;i=nxt[i])
{
int p = to[i];
if (p==faa) continue;
fa[p]=x;
dfs1(p,x);
size[x]+=size[p];
if (size[p]>maxson)
{
maxson=size[p];
son[x]=p;
}
}
}
int tot;
void dfs2(int x,int chain)
{
top[x]=chain;
tail[chain]=x;
newnum[x]=++tot;
back[tot]=x;
if (!son[x]) return;
dfs2(son[x],chain);
for (int i=point[x];i;i=nxt[i])
{
int p = to[i];
if (!newnum[p]) dfs2(p,p);
}
}
void solve(int x,int fa)
{
for (int i=point[x];i;i=nxt[i])
{
int p = to[i];
if (p==fa) continue;
solve(p,x);
sum[x]+=dp1[p];
}
if (!son[x]) dp1[x]=val[x];
else dp1[x]=min(sum[x],val[x]);
dp[x]=sum[x]-dp1[son[x]];
}
void modify(int x,int y)
{
Ju tmp = query(1,1,n,newnum[x],newnum[x]);
tmp.a[2][1]+=y;
val[x]+=y;
if (tail[top[x]]==x) tmp.a[1][1]+=y;
update(1,1,n,newnum[x],tmp);
for (int now = top[x];now!=1;now=top[now])
{
int faa = fa[now];
Ju ymh = query(1,1,n,newnum[faa],newnum[faa]);
Ju lyf = query(1,1,n,newnum[now],newnum[tail[top[now]]]);
ymh.a[1][1]+=(lyf.a[1][1]-pre[now].a[1][1]);
update(1,1,n,newnum[faa],ymh);
pre[now]=lyf;
now = fa[now];
}
}
signed main()
{
n=read();
for (int i=1;i<=n;i++) val[i]=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read();
addedge(x,y);
addedge(y,x);
}
dfs1(1,0);
dfs2(1,1);
solve(1,0);
build(1,1,n);
for (int i=1;i<=n;i++)
{
pre[i]=query(1,1,n,newnum[i],newnum[tail[top[i]]]);
}
q=read();
for (int i=1;i<=q;i++)
{
char s[10];
scanf("%s",s+1);
if (s[1]=='Q')
{
int x=read();
Ju now = query(1,1,n,newnum[x],newnum[tail[top[x]]]);
cout<<now.a[1][1]<<"
";
}
else
{
int x=read(),y=read();
modify(x,y);
}
}
return 0;
}