题目描述
给定一棵树,设计数据结构支持以下操作 1 u v d 表示将路径 (u,v) 加d 2 u v 表示询问路径 (u,v) 上点权绝对值的和
输入
第一行两个整数n和m,表示结点个数和操作数
接下来一行n个整数a_i,表示点i的权值 接下来n-1行,每行两个整数u,v表示存在一条(u,v)的边 接下来m行,每行一个操作,输入格式见题目描述
输出
对于每个询问输出答案
样例输入
4 4
-4 1 5 -2
1 2
2 3
3 4
2 1 3
1 1 4 3
2 1 3
2 3 4
-4 1 5 -2
1 2
2 3
3 4
2 1 3
1 1 4 3
2 1 3
2 3 4
样例输出
10
13
9
13
9
提示
对于100%的数据,n,m <= 10^5 且 0<= d,|a_i|<= 10^8
如果都是正数直接树链剖分+线段树就行了。
现在有了负数,那不是再维护一个区间正数个数就好了?显然是不够的。
因为区间修改时会把一些负数变为正数,会改变区间正数的个数,所以我们要维护区间三个值:
1、区间绝对值之和
2、区间非负数个数
3、区间最大的负数
当每次修改一个区间时如果这个区间的最大负数会变成非负数,那么说明这个区间的非负数个数会改变,因此要重构这个区间。
怎么重构呢?
对于这个区间的左右子区间,对于不需要重构的子区间下传标记,对于需要重构的子区间就递归重构下去。
因为每个数最多只会被重构一次,因此重构均摊O(nlogn)。总时间复杂度还是O(mlogn)级别。
#include<set> #include<map> #include<stack> #include<queue> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> #define ll long long using namespace std; int num[800010]; int mx[800010]; ll sum[800010]; int d[100010]; int f[100010]; int son[100010]; int size[100010]; int top[100010]; int to[200010]; int tot; int head[100010]; int s[100010]; int q[100010]; int n,m; int x,y,z; int opt; int cnt; ll a[800010]; int next[200010]; int v[100010]; int merge(int x,int y) { if(x<0&&y<0) { return max(x,y); } if(x<0) { return x; } if(y<0) { return y; } return 0; } void add(int x,int y) { tot++; next[tot]=head[x]; head[x]=tot; to[tot]=y; } void dfs(int x) { size[x]=1; d[x]=d[f[x]]+1; for(int i=head[x];i;i=next[i]) { if(to[i]!=f[x]) { f[to[i]]=x; dfs(to[i]); size[x]+=size[to[i]]; if(size[to[i]]>size[son[x]]) { son[x]=to[i]; } } } } void dfs2(int x,int tp) { s[x]=++cnt; top[x]=tp; q[cnt]=x; if(son[x]) { dfs2(son[x],tp); } for(int i=head[x];i;i=next[i]) { if(to[i]!=f[x]&&to[i]!=son[x]) { dfs2(to[i],to[i]); } } } void pushup(int rt) { num[rt]=num[rt<<1]+num[rt<<1|1]; sum[rt]=sum[rt<<1]+sum[rt<<1|1]; mx[rt]=merge(mx[rt<<1],mx[rt<<1|1]); } void pushdown(int rt,bool x,bool y,int l,int r) { if(a[rt]) { int mid=(l+r)>>1; if(x) { if(mx[rt<<1]) { mx[rt<<1]+=a[rt]; } sum[rt<<1]+=1ll*(2*num[rt<<1]-(mid-l+1))*a[rt]; a[rt<<1]+=a[rt]; } if(y) { if(mx[rt<<1|1]) { mx[rt<<1|1]+=a[rt]; } sum[rt<<1|1]+=1ll*(2*num[rt<<1|1]-(r-mid))*a[rt]; a[rt<<1|1]+=a[rt]; } a[rt]=0; } } void build(int rt,int l,int r) { if(l==r) { if(v[q[l]]<0) { mx[rt]=v[q[l]]; } else { num[rt]=1; } sum[rt]=abs(v[q[l]]); return ; } int mid=(l+r)>>1; build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } void rebuild(int rt,int l,int r,ll c) { if(l==r) { num[rt]=1; sum[rt]=mx[rt]+c; mx[rt]=0; return ; } int mid=(l+r)>>1; c+=a[rt]; a[rt]=c; if(mx[rt<<1]&&mx[rt<<1]+c>=0&&mx[rt<<1|1]&&mx[rt<<1|1]+c>=0) { a[rt]=0; rebuild(rt<<1,l,mid,c); rebuild(rt<<1|1,mid+1,r,c); } else if(mx[rt<<1]&&mx[rt<<1]+c>=0) { pushdown(rt,0,1,l,r); rebuild(rt<<1,l,mid,c); } else if(mx[rt<<1|1]&&mx[rt<<1|1]+c>=0) { pushdown(rt,1,0,l,r); rebuild(rt<<1|1,mid+1,r,c); } pushup(rt); } void change(int rt,int l,int r,int L,int R,int c) { if(L<=l&&r<=R) { if(mx[rt]+c>=0&&mx[rt]) { rebuild(rt,l,r,c); } else { if(mx[rt]) { mx[rt]+=c; } a[rt]+=c; sum[rt]+=1ll*(2*num[rt]-(r-l+1))*c; } return ; } int mid=(l+r)>>1; pushdown(rt,1,1,l,r); if(L<=mid) { change(rt<<1,l,mid,L,R,c); } if(R>mid) { change(rt<<1|1,mid+1,r,L,R,c); } pushup(rt); } ll query(int rt,int l,int r,int L,int R) { if(L<=l&&r<=R) { return sum[rt]; } pushdown(rt,1,1,l,r); int mid=(l+r)>>1; long long res=0; if(L<=mid) { res+=query(rt<<1,l,mid,L,R); } if(R>mid) { res+=query(rt<<1|1,mid+1,r,L,R); } return res; } void updata(int x,int y,int z) { while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) { swap(x,y); } change(1,1,n,s[top[x]],s[x],z); x=f[top[x]]; } if(d[x]>d[y]) { swap(x,y); } change(1,1,n,s[x],s[y],z); } ll downdata(int x,int y) { ll res=0; while(top[x]!=top[y]) { if(d[top[x]]<d[top[y]]) { swap(x,y); } res+=query(1,1,n,s[top[x]],s[x]); x=f[top[x]]; } if(d[x]>d[y]) { swap(x,y); } res+=query(1,1,n,s[x],s[y]); return res; } int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) { scanf("%d",&v[i]); } for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1); dfs2(1,1); build(1,1,n); while(m--) { scanf("%d",&opt); scanf("%d%d",&x,&y); if(opt==1) { scanf("%d",&z); updata(x,y,z); } else { printf("%lld ",downdata(x,y)); } } }