题目
描述
给定一棵 $n$ 个点的树和一个 $n$ 元排列 $a_i$,有 $q$ 个询问,每次询问给出一个 $k$,求:
$$
\sum_{l=1}^{k}\sum_{r=l}^{k}\sum_{i=l}^{r}\sum_{j=i}^{r}\operatorname{dis}(a_i,a_j)
$$
其中 $k\le n$,$\operatorname{dis}(u,v)$ 为 $u$ 和 $v$ 的树上最短距离。
结果对 $998244353$ 取模。
范围
$n,q \le 10^5$,$u,v,k \le n$
题解:
-
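设 $f(l,r)=\sum_{l\le i\le j\le r}\operatorname{dis}(a_i,a_j)$,则答案为 $F(k)=\sum_{1\le l\le r\le k}f(l,r)$。考虑前缀长度从 $n-1$ 增加到 $n$ 时的增量 $F(n)-F(n-1)=\sum_{l=1}^{n}f(l,n)$: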
-
不含 $a_n$ 的点对部分与上一次的增量 $\sum_{l=1}^{n-1}f(l,n-1)$ 完全相同(把区间 $[l,n-1]$ 扩展为 $[l,n]$ 时,原有点对仍被计入);
-
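新增的是所有含 $a_n$ 的点对 $(i,n)$。点对 $(i,n)$ 恰好被 $l=1,\dots,i$ 这 $i$ 个区间 $[l,n]$ 计入,所以新增部分为 $\sum_{i=1}^{n} i\cdot\operatorname{dis}(a_n,a_i)=\sum_{i=1}^{n}\bigl(i\cdot dep(a_n)+i\cdot dep(a_i)-2i\cdot dep(\operatorname{lca}(a_i,a_n))\bigr)$;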
-
前两项可以直接累加,只需要考虑 $\sum_{i=1}^{n} dep(\operatorname{lca}(a_i,a_n))\times i$;
-
处理 $a_i$ 时,把它到根的树链上每个点加上 $i$,然后查询 $a_n$ 到根的链和即可得到上式。链修改和链求和可用树剖或者 LCT 维护。
#include <cstdio>

// Solution for "sumsumsum".
//
// For every prefix length k this program precomputes
//     F(k) = sum_{1<=l<=r<=k} sum_{l<=i<=j<=r} dis(a_i, a_j)   (mod 998244353)
// and answers each query with a table lookup.
//
// Let G(k) = F(k) - F(k-1) = sum_{l<=k} f(l, k). Then
//     G(k) = G(k-1) + sum_{i<=k} i * dis(a_i, a_k),
// because the pair (i, k) is counted by the i intervals [1..i, k]. Expanding dis:
//     sum_i i*dis(a_i,a_k) = dep(a_k)*k(k+1)/2 + sum_i i*dep(a_i)
//                            - 2 * sum_i i*dep(lca(a_i, a_k)).
// The last term is maintained by adding i to every node on the root..a_i chain
// and reading the root..a_k chain sum. This uses a link-cut tree rooted at node 1.
// Depths count nodes (dep[1] = 1); the offset cancels inside dis().

using ll = long long;

constexpr int MOD = 998244353;
constexpr int N = 100010;

int n, m;

// Adjacency list. Edge ids start at 1 so that 0 terminates a list.
struct Edge {
    int to, next;
} edges[N << 1];
int head[N], edge_cnt = 1;

void add_edge(int u, int v) {
    edges[edge_cnt] = Edge{v, head[u]};
    head[u] = edge_cnt++;
    edges[edge_cnt] = Edge{u, head[v]};
    head[v] = edge_cnt++;
}

// ---- buffered input ------------------------------------------------------

// Returns the next input byte, or EOF once stdin is exhausted. The return type is
// int so EOF stays distinguishable from every byte value.
int gc() {
    static char buf[1000000];
    static char *p1 = buf, *p2 = buf;
    if (p1 == p2) {
        p1 = buf;
        p2 = buf + std::fread(buf, 1, sizeof buf, stdin);
        if (p1 == p2) return EOF;
    }
    return static_cast<unsigned char>(*p1++);
}

// Reads a non-negative decimal integer, skipping any non-digit bytes before it.
// Returns 0 (instead of spinning forever) if input ends before a digit.
int rd() {
    int c = gc();
    while (c != EOF && (c < '0' || c > '9')) c = gc();
    int x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = gc();
    }
    return x;
}

// ---- buffered output -----------------------------------------------------

char obuf[1000000];
char *op = obuf;

void put_char(char c) {
    if (op == obuf + sizeof obuf) {
        std::fwrite(obuf, 1, sizeof obuf, stdout);
        op = obuf;
    }
    *op++ = c;
}

// Writes a non-negative integer followed by a single space.
void write(int x) {
    char digits[12];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + x % 10);
        x /= 10;
    } while (x);
    while (len) put_char(digits[--len]);
    put_char(' ');
}

void flush() {
    std::fwrite(obuf, 1, static_cast<std::size_t>(op - obuf), stdout);
    op = obuf;
}

// ---- link-cut tree (root fixed at node 1; never re-rooted) ---------------
// Node 0 is the null sentinel. Its sz/sum stay 0 and it never receives tags.

int ch[N][2], fa[N];
int w[N];    // this node's own value (mod MOD)
int sum[N];  // sum of w over this node's splay subtree (mod MOD)
int ly[N];   // pending add for the children; already applied to this node
int sz[N];   // splay subtree size
int dep[N];  // depth in the rooted tree, dep[1] = 1
int ans[N];  // ans[k] = F(k)
int stk[N];  // scratch stack for push_path
int que[N];  // BFS queue for build_tree

// Recomputes k's aggregates from its children.
void pushup(int k) {
    sum[k] = static_cast<int>((static_cast<ll>(sum[ch[k][0]]) + sum[ch[k][1]] + w[k]) % MOD);
    sz[k] = sz[ch[k][0]] + sz[ch[k][1]] + 1;
}

// Adds v (0 <= v < MOD) to every value in k's splay subtree; no-op on the sentinel.
void apply_add(int k, int v) {
    if (!k) return;
    sum[k] = static_cast<int>((sum[k] + static_cast<ll>(sz[k]) * v) % MOD);
    w[k] += v;
    if (w[k] >= MOD) w[k] -= MOD;
    ly[k] += v;
    if (ly[k] >= MOD) ly[k] -= MOD;
}

void pushdown(int k) {
    if (ly[k]) {
        apply_add(ch[k][0], ly[k]);
        apply_add(ch[k][1], ly[k]);
        ly[k] = 0;
    }
}

// True when x is the root of its splay tree (its fa link is a path-parent pointer).
bool isrt(int x) { return ch[fa[x]][0] != x && ch[fa[x]][1] != x; }

// Pushes pending tags from x's splay root down to x. This is iterative because
// splay depth can reach n (e.g. right after accessing the bottom of a chain).
void push_path(int x) {
    int top = 0;
    stk[top++] = x;
    for (int y = x; !isrt(y); y = fa[y]) stk[top++] = fa[y];
    while (top) pushdown(stk[--top]);
}

void rotate(int x) {
    int y = fa[x], z = fa[y];
    int side = ch[y][1] == x;  // which child of y x is
    int moved = ch[x][side ^ 1];
    if (!isrt(y)) ch[z][ch[z][1] == y] = x;
    fa[x] = z;
    ch[y][side] = moved;
    if (moved) fa[moved] = y;
    ch[x][side ^ 1] = y;
    fa[y] = x;
    pushup(y);
    pushup(x);
}

void splay(int x) {
    push_path(x);
    while (!isrt(x)) {
        int y = fa[x], z = fa[y];
        if (!isrt(y)) rotate(((ch[y][0] == x) != (ch[z][0] == y)) ? x : y);
        rotate(x);
    }
}

// Makes root..x one preferred path (one splay tree) with x its deepest node.
void access(int x) {
    for (int y = 0; x; y = x, x = fa[x]) {
        splay(x);
        ch[x][1] = y;
        pushup(x);
    }
}

// Iterative BFS from node 1: fills dep[] and sets fa[v] = parent(v). With all
// splay children still empty, this is a valid link-cut tree for the tree rooted
// at 1 with every edge light. Iteration avoids deep recursion on chain-shaped trees.
void build_tree() {
    int qh = 0, qt = 0;
    que[qt++] = 1;
    dep[1] = 1;
    while (qh < qt) {
        int u = que[qh++];
        for (int e = head[u]; e; e = edges[e].next) {
            int v = edges[e].to;
            if (v == fa[u]) continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            que[qt++] = v;
        }
    }
}

int main() {
    std::freopen("sumsumsum.in", "r", stdin);
    std::freopen("sumsumsum.out", "w", stdout);
    n = rd();
    m = rd();
    for (int i = 1; i <= n; ++i) sz[i] = 1;
    for (int i = 1; i < n; ++i) {
        int u = rd(), v = rd();
        add_edge(u, v);
    }
    build_tree();

    int weighted_depth = 0;  // sum_{j<=i} j * dep(a_j)
    int inc = 0;             // G(i)
    for (int i = 1; i <= n; ++i) {
        int x = rd();
        weighted_depth = static_cast<int>((weighted_depth + static_cast<ll>(i) * dep[x]) % MOD);

        // Add i along root..x; afterwards sum[x] = sum_{j<=i} j * dep(lca(a_j, x)).
        access(x);
        splay(x);
        apply_add(x, i);

        ll delta = static_cast<ll>(i) * (i + 1) / 2 % MOD * dep[x] % MOD
                   + weighted_depth - 2LL * sum[x];
        inc = static_cast<int>(((inc + delta) % MOD + MOD) % MOD);
        ans[i] = static_cast<int>((ans[i - 1] + static_cast<ll>(inc)) % MOD);
    }

    // Assumes every query k satisfies 0 <= k <= n, as the problem guarantees.
    for (int i = 1; i <= m; ++i) write(ans[rd()]);
    flush();
    return 0;
}