这道题显然是树链剖分 + 线段树维护路径点权的三次方和。把 $(a+x)^3 = a^3 + 3a^2x + 3ax^2 + x^3$ 拆项后可以发现,区间加时需要用到一次方和与二次方和,所以线段树要同时维护一次方和、二次方和与三次方和。
这里涉及三种修改操作:赋值、加(add)和乘(mul)。赋值可以看作“先乘 0 再加 w”,所以懒标记只需要 mul 和 add 两个。
因此要考虑两个懒标记的优先级。做法与洛谷线段树模板 2(P3373)相同:约定先乘后加,即节点的值表示为 a·mul + add。当已有加标记的节点再被乘以 y 时,令 mul = mul·y、add = add·y,这样就能正确处理“先加后乘”的情况。
#include <cstdio>
#include <utility>
#include <vector>

// Heavy-light decomposition over a segment tree that keeps, for every range,
// the sums of a_i, a_i^2 and a_i^3 (mod 1e9+7). Path updates: assign (opt 1),
// add (opt 2), multiply (opt 3); any other opt queries the path sum of a_i^3.

using ll = long long;

const int N = 2e5;        // edge arrays hold 2*(n-1) entries, so n <= 100001
const int mod = 1e9 + 7;

int h[N], ne[N], e[N], idx;               // adjacency lists: head / next / target
int son[N], pre[N], id[N], sz[N], fa[N];  // heavy child, position, pre^-1, subtree size, parent
int n;
int depth[N], top[N], stamp;              // depth, chain head, last assigned position
ll w[N];                                  // vertex values, reduced into [0, mod) on input

struct node {
    int l, r;
    ll mul;   // pending multiply tag; a node's pending op is a -> a*mul + ad
    ll ad;    // pending add tag (applied after mul)
    ll sum1;  // sum of a_i over [l, r]
    ll sum2;  // sum of a_i^2
    ll sum3;  // sum of a_i^3
} tr[N << 2];

// Append the directed edge a -> b to the adjacency lists.
void add(int a, int b) {
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx++;
}

// Map any input value (negative, or >= mod) into [0, mod) so every product
// in the tag helpers below stays under 2^63.
static ll norm(ll v) {
    v %= mod;
    return v < 0 ? v + mod : v;
}

// Compute fa, depth, sz and the heavy child son[] for the tree rooted at u.
// Iterative (explicit stack) so a path-shaped tree cannot overflow the call
// stack. Expects fa[u], depth[u] set and sz/son zeroed for vertices 0..n.
void dfs(int u) {
    std::vector<int> order;  // every vertex appears after its parent
    order.reserve(n);
    std::vector<int> stk(1, u);
    while (!stk.empty()) {
        const int v = stk.back();
        stk.pop_back();
        order.push_back(v);
        for (int i = h[v]; i != -1; i = ne[i]) {
            const int j = e[i];
            if (j == fa[v]) continue;
            fa[j] = v;
            depth[j] = depth[v] + 1;
            stk.push_back(j);
        }
    }
    // Reverse sweep: all children are finalized before their parent.
    for (int k = static_cast<int>(order.size()) - 1; k >= 0; --k) {
        const int v = order[k];
        sz[v] += 1;
        if (v == u) continue;  // the root has no parent; keep sz[0] == 0
        const int p = fa[v];
        sz[p] += sz[v];
        if (sz[v] > sz[son[p]]) son[p] = v;  // sz[0] == 0 handles "no heavy child yet"
    }
}

// Assign positions so every heavy chain occupies a contiguous, top-first
// range [pre[top], pre[bottom]]; u starts the chain headed by x. Iterative
// for the same stack-depth reason as dfs. Only path queries are used, so
// subtree contiguity is not required.
void dfs1(int u, int x) {
    std::vector<std::pair<int, int>> stk;  // (chain start vertex, chain head)
    stk.emplace_back(u, x);
    while (!stk.empty()) {
        const int start = stk.back().first;
        const int head = stk.back().second;
        stk.pop_back();
        for (int v = start; v != 0; v = son[v]) {
            pre[v] = ++stamp;
            id[stamp] = v;
            top[v] = head;
            for (int i = h[v]; i != -1; i = ne[i]) {
                const int j = e[i];
                if (j != fa[v] && j != son[v]) stk.emplace_back(j, j);  // light child heads a new chain
            }
        }
    }
}

// Recompute node u's three sums from its children.
void pushup(int u) {
    tr[u].sum1 = (tr[u << 1].sum1 + tr[u << 1 | 1].sum1) % mod;
    tr[u].sum2 = (tr[u << 1].sum2 + tr[u << 1 | 1].sum2) % mod;
    tr[u].sum3 = (tr[u << 1].sum3 + tr[u << 1 | 1].sum3) % mod;
}

// Multiply every value in node u's range by y (0 <= y < mod).
void apply_mul(int u, ll y) {
    node &t = tr[u];
    t.sum3 = t.sum3 * y % mod * y % mod * y % mod;
    t.sum2 = t.sum2 * y % mod * y % mod;
    t.sum1 = t.sum1 * y % mod;
    t.mul = t.mul * y % mod;
    t.ad = t.ad * y % mod;  // (a*mul + ad)*y = a*(mul*y) + ad*y
}

// Add x to every value in node u's range (0 <= x < mod), using
// (a+x)^3 = a^3 + 3a^2x + 3ax^2 + x^3 and (a+x)^2 = a^2 + 2ax + x^2.
// sum3 is updated before sum2, and sum2 before sum1, because each formula
// needs the lower-order sums from before the addition.
void apply_add(int u, ll x) {
    node &t = tr[u];
    const ll len = t.r - t.l + 1;
    const ll x2 = x * x % mod;
    const ll x3 = x2 * x % mod;
    t.sum3 = (t.sum3 + 3 * x % mod * t.sum2 % mod + 3 * x2 % mod * t.sum1 % mod + len * x3 % mod) % mod;
    t.sum2 = (t.sum2 + 2 * x % mod * t.sum1 % mod + len * x2 % mod) % mod;
    t.sum1 = (t.sum1 + len * x % mod) % mod;
    t.ad = (t.ad + x) % mod;
}

// Set every value in node u's range to x (0 <= x < mod): the tag becomes
// "multiply by 0, then add x".
void apply_set(int u, ll x) {
    node &t = tr[u];
    const ll len = t.r - t.l + 1;
    t.sum1 = len * x % mod;
    t.sum2 = t.sum1 * x % mod;
    t.sum3 = t.sum2 * x % mod;
    t.mul = 0;
    t.ad = x;
}

// Build the segment tree over positions [l, r] from w[id[.]].
void build(int u, int l, int r) {
    tr[u].l = l;
    tr[u].r = r;
    tr[u].mul = 1;
    tr[u].ad = 0;
    if (l == r) {
        const ll v = w[id[l]];
        tr[u].sum1 = v;
        tr[u].sum2 = v * v % mod;
        tr[u].sum3 = tr[u].sum2 * v % mod;
        return;
    }
    const int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Apply a parent's pending tag (multiply by y, then add x) to node u.
void down(int u, ll x, ll y) {
    if (y != 1) apply_mul(u, y);
    if (x != 0) apply_add(u, x);
}

// Push node u's pending tag to both children and clear it.
void pushdown(int u) {
    const ll y = tr[u].mul;
    const ll x = tr[u].ad;
    down(u << 1, x, y);
    down(u << 1 | 1, x, y);
    tr[u].mul = 1;
    tr[u].ad = 0;
}

// Apply operation opt (1 = set, 2 = add, 3 = multiply) with operand x
// (already in [0, mod)) to positions [l, r].
void modify(int u, int l, int r, ll x, int opt) {
    if (tr[u].l >= l && tr[u].r <= r) {
        if (opt == 1) apply_set(u, x);
        else if (opt == 2) apply_add(u, x);
        else if (opt == 3) apply_mul(u, x);
        return;
    }
    pushdown(u);
    const int mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) modify(u << 1, l, r, x, opt);
    if (r > mid) modify(u << 1 | 1, l, r, x, opt);
    pushup(u);
}

// Apply operation opt with operand z to every vertex on the path x..y.
void change(int x, int y, ll z, int opt) {
    while (top[x] != top[y]) {
        if (depth[top[x]] < depth[top[y]]) std::swap(x, y);
        modify(1, pre[top[x]], pre[x], z, opt);
        x = fa[top[x]];
    }
    if (depth[x] > depth[y]) std::swap(x, y);  // same chain: shallower vertex has smaller pre
    modify(1, pre[x], pre[y], z, opt);
}

// Sum of a_i^3 (mod p) over positions [l, r].
ll query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum3;
    pushdown(u);
    const int mid = (tr[u].l + tr[u].r) >> 1;
    ll ans = 0;
    if (l <= mid) ans = query(u << 1, l, r);
    if (r > mid) ans = (ans + query(u << 1 | 1, l, r)) % mod;
    return ans;
}

// Sum of a_i^3 (mod p) over every vertex on the path x..y.
ll qpath(int x, int y) {
    ll res = 0;
    while (top[x] != top[y]) {
        if (depth[top[x]] < depth[top[y]]) std::swap(x, y);
        res = (res + query(1, pre[top[x]], pre[x])) % mod;
        x = fa[top[x]];
    }
    if (depth[x] > depth[y]) std::swap(x, y);
    return (res + query(1, pre[x], pre[y])) % mod;
}

int main() {
    int t;
    if (scanf("%d", &t) != 1) return 0;
    for (int cas = 1; cas <= t; ++cas) {
        if (scanf("%d", &n) != 1) break;
        // Reset only the slots this case uses (vertices / positions 0..n)
        // instead of memset-ing whole N-sized arrays for every test case.
        idx = 0;
        stamp = 0;
        for (int i = 0; i <= n; ++i) {
            h[i] = -1;
            sz[i] = son[i] = depth[i] = id[i] = fa[i] = top[i] = 0;
        }
        printf("Case #%d: ", cas);
        for (int i = 1; i < n; ++i) {
            int a, b;
            scanf("%d%d", &a, &b);
            add(a, b);
            add(b, a);
        }
        for (int i = 1; i <= n; ++i) {
            scanf("%lld", &w[i]);
            w[i] = norm(w[i]);
        }
        depth[1] = 1;
        fa[1] = 0;
        dfs(1);
        dfs1(1, 1);
        build(1, 1, n);
        int q;
        scanf("%d", &q);
        while (q--) {
            int opt, u, v;
            scanf("%d", &opt);
            if (opt == 1 || opt == 2 || opt == 3) {
                ll val;
                scanf("%d%d%lld", &u, &v, &val);
                change(u, v, norm(val), opt);
            } else {
                scanf("%d%d", &u, &v);
                printf("%lld ", qpath(u, v));
            }
        }
        // End this case's line; previously all cases ran together on one line.
        // NOTE(review): confirm the exact judge format (answers space-separated vs one per line).
        putchar('\n');
    }
    return 0;
}