这题的区间维护比较麻烦,顺便复习了一下区间合并
维护区间间隔色段数,跨链时更新一下上一条链顶的颜色,去重
#include<bits/stdc++.h> //#pragma comment(linker, "/STACK:1024000000,1024000000") #include<stdio.h> #include<algorithm> #include<queue> #include<string.h> #include<iostream> #include<math.h> #include<set> #include<map> #include<vector> #include<iomanip> using namespace std; #define ll long long #define pb push_back #define FOR(a) for(int i=1;i<=a;i++) const int inf=0x3f3f3f3f; const int maxn=1e5+9; int arr[maxn]; int n,q; struct EDGE{ int v;int next; }G[maxn<<1]; int head[maxn],tot; void addedge(int u,int v){ ++tot;G[tot].v=v;G[tot].next=head[u];head[u]=tot; } int top[maxn]; int pre[maxn]; int dep[maxn]; int num[maxn]; int p[maxn]; //v的对应位置 int out[maxn]; //退出时间戳 int fp[maxn]; //访问序列 int son[maxn]; //重儿子 int pos; void init(){ memset(head,-1,sizeof head);tot=0; memset(son,-1,sizeof son); } void dfs1(int u,int fa,int d){ dep[u]=d; pre[u]=fa; num[u]=1; for(int i=head[u];~i;i=G[i].next){ int v=G[i].v; if(v==fa)continue; dfs1(v,u,d+1); num[u]+=num[v]; if(son[u]==-1||num[v]>num[son[u]])son[u]=v; } } void getpos(int u,int sp){ top[u]=sp; p[u]=out[u]=++pos; fp[p[u]]=u; if(son[u]==-1)return; getpos(son[u],sp); for(int i=head[u];~i;i=G[i].next){ int v=G[i].v; if(v!=son[u]&&v!=pre[u])getpos(v,v); } out[u]=pos; } struct NODE{ int lcol,rcol,sum,lazy; }ST[maxn<<2]; void pushup(int rt){ ST[rt].sum=ST[rt<<1].sum+ST[rt<<1|1].sum; if(ST[rt<<1].rcol==ST[rt<<1|1].lcol)ST[rt].sum--; ST[rt].lcol=ST[rt<<1].lcol;ST[rt].rcol=ST[rt<<1|1].rcol; } void pushdown(int rt){ if(!ST[rt].lazy)return; ST[rt<<1].lcol=ST[rt<<1|1].lcol=ST[rt<<1].rcol=ST[rt<<1|1].rcol= ST[rt].lcol; ST[rt<<1].sum=ST[rt<<1|1].sum=1; ST[rt<<1].lazy=ST[rt<<1|1].lazy=1; ST[rt].lazy=0; } void build(int l,int r,int rt){ if(l==r){ST[rt].sum=1;ST[rt].lcol=ST[rt].rcol=arr[fp[l]];return;} int m=l+r>>1;build(l,m,rt<<1);build(m+1,r,rt<<1|1);pushup(rt); } void update(int a,int b,int c,int l,int r,int rt){ if(a<=l&&b>=r){ ST[rt].sum=1; ST[rt].lcol=ST[rt].rcol=c; ST[rt].lazy=1; return; } pushdown(rt); int m=l+r>>1; if(a<=m)update(a,b,c,l,m,rt<<1); if(b>m)update(a,b,c,m+1,r,rt<<1|1); pushup(rt); } int L,R; int query(int a,int b,int l,int r,int rt){ if(a==l)L=ST[rt].lcol; if(b==r)R=ST[rt].rcol; if(a<=l&&b>=r)return ST[rt].sum; pushdown(rt); int m=l+r>>1; int ans=0; if(b<=m)return query(a,b,l,m,rt<<1); else if(a>m)return query(a,b,m+1,r,rt<<1|1); if(ST[rt<<1].rcol==ST[rt<<1|1].lcol)ans--; ans+=query(a,b,l,m,rt<<1);ans+=query(a,b,m+1,r,rt<<1|1); return ans; } void solve1(int x,int y,int z){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]])swap(x,y); update(p[top[x]],p[x],z,1,n,1); x=pre[top[x]]; } if(dep[x]<dep[y])swap(x,y); update(p[y],p[x],z,1,n,1); } void solve2(int x,int y){ int ans=0,ans1=-1,ans2=-1;//上次链的左端颜色 while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]){swap(x,y);swap(ans1,ans2);} ans+=query(p[top[x]],p[x],1,n,1); //cout<<L<<" "<<R<<endl; if(R==ans1)ans--; ans1=L;x=pre[top[x]]; } if(dep[x]<dep[y]){swap(x,y);swap(ans1,ans2);} ans+=query(p[y],p[x],1,n,1); if(R==ans1)ans--;if(L==ans2)ans--; printf("%d ",ans); } char op[5]; int main(){ scanf("%d%d",&n,&q); init(); for(int i=1;i<=n;i++){scanf("%d",&arr[i]);} for(int i=1,x,y;i<n;i++){ scanf("%d%d",&x,&y);addedge(x,y);addedge(y,x); } dfs1(1,1,0);getpos(1,1); int x,y,z; build(1,n,1); while(q--){ scanf("%s",op); if(op[0]=='Q'){ scanf("%d%d",&x,&y); solve2(x,y); }else{ scanf("%d%d%d",&x,&y,&z); solve1(x,y,z); } } }