树上路径 T20 D71
思路:
((a+b+c)^2=a^2+b^2+c^2+2ab+2ac+2bc)
那么 (ab+bc+ac=((a+b+c)^2-(a^2+b^2+c^2))/2)
线段树维护 (a+b+c)和 (a^2+b^2+c^2)
假设一个含a,b,c的区间+t。
((a+t)^2+(b+t)^2+(c+t)^2=a^2+b^2+c^2+3t^2+2t(a+b+c))
即 (a^2+b^2+c^2)的值为它本身的值+区间长度乘t的平方+2t(a+b+c)的值
树剖后线段树维护即可。
#include<bits/stdc++.h>
#define ll long long
#define pii pair<long long,long long>
#define fi first
#define se second
#define pb push_back
#define si size()
#define ls (p<<1)
#define rs ((p<<1)|1)
#define mid (t[p].l+t[p].r)/2
using namespace std;
ll read(){ll x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}return x*f;}
inline void Prin(ll x){if(x < 0){putchar('-');x = -x;}if(x > 9) Prin(x / 10);putchar(x % 10 + '0');}
const int qs=1e5+7;
const int mod=1e9+7;
ll n,m,q,a[qs];
int dep[qs],f[qs],fr[qs],sz[qs],son[qs],top[qs],cnt=0;
vector<int> v[qs];
pii id[qs];
ll qpow(ll a,ll b){
ll ret=1;
while(b){
if(b&1) ret=ret*a%mod;
a=a*a%mod;
b>>=1;
}
return ret;
}
struct Tree{
ll val,sum,add;
int l,r;
#define l(x) t[x].l
#define r(x) t[x].r
#define val(x) t[x].val
#define add(x) t[x].add
#define sum(x) t[x].sum
}t[qs<<2];
void pushup(int p){ val(p)=(val(ls)+val(rs))%mod;sum(p)=(sum(ls)+sum(rs))%mod;}
void down(int p){
if(!add(p)) return;
sum(ls)=(sum(ls)+(r(ls)-l(ls)+1)*add(p)%mod*add(p)%mod+2*add(p)%mod*val(ls));
sum(rs)=(sum(rs)+(r(rs)-l(rs)+1)*add(p)%mod*add(p)%mod+2*add(p)%mod*val(rs));
val(ls)=(val(ls)+add(p)*(r(ls)-l(ls)+1))%mod;
val(rs)=(val(rs)+add(p)*(r(rs)-l(rs)+1))%mod;
add(ls)=(add(p)+add(ls))%mod;
add(rs)=(add(p)+add(rs))%mod;
add(p)=0;
}
void build(int p,int l,int r){
l(p)=l,r(p)=r;add(p)=0;
if(l==r){
val(p)=a[fr[l]]%mod;
sum(p)=val(p)*val(p)%mod;
return;
}
build(ls,l,mid);
build(rs,mid+1,r);
pushup(p);
}
void update(int p,int l,int r,ll val){
if(l<=l(p)&&r>=r(p)){
sum(p)=(sum(p)+(r(p)-l(p)+1)*val%mod*val%mod+2*val%mod*val(p));
val(p)=(val(p)+val*(r(p)-l(p)+1))%mod;
add(p)=(val+add(p))%mod;
return;
}
down(p);
if(l<=mid) update(ls,l,r,val);
if(r>mid) update(rs,l,r,val);
pushup(p);
}
ll fm;
pii ask(int p,int l,int r){
if(l<=l(p)&&r>=r(p)) {
pii res={val(p),sum(p)};
return res;
}
down(p);
pii val={0,0};
if(l<=mid) {
pii ft=ask(ls,l,r);
val.fi=(val.fi+ft.fi)%mod;
val.se=(val.se+ft.se)%mod;
}
if(r>mid){
pii ft=ask(rs,l,r);
val.fi=(val.fi+ft.fi)%mod;
val.se=(val.se+ft.se)%mod;
}
return val;
}
void dfs(int x,int fa){
dep[x]=dep[fa]+1; f[x]=fa; sz[x]=1; son[x]=0;
int ms=0;
for(int i=0;i<v[x].si;++i){
int p=v[x][i];
if(p==fa) continue;
dfs(p,x);
sz[x]+=sz[p];
if(sz[p]>ms) son[x]=p,ms=sz[p];
}
}
void dfn(int x,int po){
id[x].fi=++cnt;
fr[cnt]=x;
top[x]=po;
if(!son[x]) {
id[x].se=cnt; return;
}
dfn(son[x],po);
for(int i=0;i<v[x].si;++i){
int p=v[x][i];
if(p==f[x]||p==son[x]) continue;
dfn(p,p);
}
id[x].se=cnt;
}
void updRange(int x,int y,ll k){
k%=mod;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,id[top[x]].fi,id[x].fi,k);
x=f[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,id[x].fi,id[y].fi,k);
}
ll qRange(int x,int y){
ll ans=0;
pii res={0,0},ft;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ft=ask(1,id[top[x]].fi,id[x].fi);
res.fi=(res.fi+ft.fi)%mod;
res.se=(res.se+ft.se)%mod;
x=f[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ft=ask(1,id[x].fi,id[y].fi);
res.fi=(res.fi+ft.fi)%mod;
res.se=(res.se+ft.se)%mod;
ans=(res.fi*res.fi%mod-res.se+mod)%mod*fm%mod;
return ans;
}
int main(){
fm= qpow(2,mod-2);
n=read(),m=read();
ll x,y,op;
for(int i=1;i<=n;++i){
a[i]=read();
}
for(int i=1;i<n;++i){
x=read(),y=read();
v[x].pb(y);
v[y].pb(x);
}
dfs(1,0);
dfn(1,1);
build(1,1,n);
while(m--){
op=read(),x=read(),y=read();
if(op==1){
update(1,id[x].fi,id[x].se,y);
}
else if(op==2){
ll z; z=read();
updRange(x,y,z);
}
else{
cout<<qRange(x,y)<<"
";
}
}
return 0;
}
/*
*/