题解:
先考虑在在区间上如何做这个操作。考虑两个相邻的区间A,B,不妨设区间A在区间B的左端,设区间A的颜色段数量为 $sum_A$ 区间B的颜色段数量为 $sum_B$,那么将区间A和B合并后颜色段的数量是否是 $sum_A+sum_B$ 呢?显然不是,如果区间A的右端和区间B的左端颜色相同的话,答案应该是 $sum_A+sum_B -1$,画个图很好理解。合并后的大区间的左端颜色显然和区间A的左端颜色相同,大区间的右端颜色和B的右端颜色相同。那么显然合并这个过程是满足“区间可加性”的,于是我们可以用线段树来维护某一段区间的左端颜色、右端颜色与颜色段。
那么问题来了,现在并不是在区间上,而是在树上!这时候就是我们的“树链剖分”大展身手的时刻了!树链剖分能将树分成一条一条链,且链上编号是连续的,很方便我们用线段树来维护。
附上代码(注释很详细哦):
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int N=200000+5; 4 5 //struct数组存线段树,lson与rson代表左右区间,lco与rco代表左右两端的颜色 6 //sum代表区间颜色块的数量 7 struct SegementTree{ 8 int lson,rson,lco,rco,sum; 9 }t[4*N]; 10 11 int c[N],size[N],fa[N],son[N],d[N],top[N],id[N],color[N],cm[N]; 12 int ver[2*N],nxt[2*N],head[N]; 13 int n,m,tot,cnt; 14 char or1; 15 16 int read(){ 17 int x=0,f=1; 18 char ch=getchar(); 19 while(ch<'0' ||ch>'9'){if(ch=='-') f=-1;ch=getchar();} 20 while(ch>='0' && ch<='9') {x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} 21 return x*f; 22 } 23 24 void add(int x,int y){ver[++tot]=y,nxt[tot]=head[x],head[x]=tot;} 25 26 void dfs1(int x){ 27 size[x]=1;d[x]=d[fa[x]]+1; 28 for(int i=head[x];i;i=nxt[i]){ 29 int y=ver[i]; 30 if(y==fa[x]) continue; 31 fa[y]=x; 32 dfs1(y); 33 size[x]+=size[y]; 34 if(size[y]>size[son[x]]) son[x]=y; 35 } 36 } 37 38 void dfs2(int x,int tp){ 39 top[x]=tp; 40 //给树上节点按照“重儿子”先遍历的顺序重新编号,c[x]保存的是节点x的颜色,因为要建线段树,所以用color数组按照编号来保存颜色 41 //执行一个类似"离散化"的过程 42 id[x]=++cnt,color[cnt]=c[x]; 43 if(!son[x]) return; 44 dfs2(son[x],tp); 45 for(int i=head[x];i;i=nxt[i]){ 46 int y=ver[i]; 47 if(y==son[x] || y==fa[x]) continue; 48 dfs2(y,y); 49 } 50 } 51 52 //正常建线段树就好了 53 void build(int p,int l,int r){ 54 t[p].lson=l,t[p].rson=r; 55 if(l==r){ 56 t[p].lco=t[p].rco=color[l];//将颜色按照编号对应过来 57 t[p].sum=1; 58 return; 59 } 60 int mid=(l+r)>>1; 61 build(p<<1,l,mid); 62 build(p<<1|1,mid+1,r); 63 //继承左右区间的左右两端颜色 64 t[p].lco=t[p<<1].lco,t[p].rco=t[p<<1|1].rco; 65 t[p].sum=t[p<<1].sum+t[p<<1|1].sum; 66 //如果左右区间“连接处”颜色相同的话,颜色块数量要减1 67 if(t[p<<1].rco==t[p<<1|1].lco) t[p].sum--; 68 } 69 70 void spread(int p){//下传延迟标记 71 if(cm[p]){ 72 t[p<<1].lco=t[p<<1].rco=cm[p]; 73 t[p<<1|1].lco=t[p<<1|1].rco=cm[p]; 74 t[p<<1].sum=t[p<<1|1].sum=1; 75 //因为颜色是覆盖的所以不用管以前的标记 76 cm[p<<1]=cm[p<<1|1]=cm[p]; 77 cm[p]=0; 78 } 79 } 80 81 void changetree(int p,int l,int r,int c){ 82 if(l<=t[p].lson && r>=t[p].rson){//正常修改就行了 83 t[p].lco=t[p].rco=c; 84 t[p].sum=1; 85 cm[p]=c; 86 return; 87 } 88 spread(p); 89 int mid=(t[p].lson+t[p].rson)>>1; 90 if(l<=mid) changetree(p<<1,l,r,c); 91 if(r>mid) changetree(p<<1|1,l,r,c); 92 t[p].lco=t[p<<1].lco,t[p].rco=t[p<<1|1].rco; 93 t[p].sum=t[p<<1].sum+t[p<<1|1].sum; 94 if(t[p<<1].rco==t[p<<1|1].lco) t[p].sum--; 95 } 96 97 void change(int a,int b,int c){ 98 while(top[a]!=top[b]){ 99 if(d[top[a]]<d[top[b]]) swap(a,b); 100 changetree(1,id[top[a]],id[a],c);//改变树链上的颜色 101 a=fa[top[a]]; 102 } 103 if(d[a]>d[b]) swap(a,b); 104 changetree(1,id[a],id[b],c); 105 } 106 107 int enqurytree(int p,int l,int r){//正常在区间上查询就行了 108 if(l<=t[p].lson && r>=t[p].rson){return t[p].sum;} 109 spread(p); 110 int mid=(t[p].lson+t[p].rson)>>1,ans=0; 111 if(l<=mid && r>mid){ 112 ans=enqurytree(p<<1,l,r)+enqurytree(p<<1|1,l,r); 113 if(t[p<<1].rco==t[p<<1|1].lco) ans--; 114 return ans; 115 }else if(l>mid && r>mid){ 116 ans=enqurytree(p<<1|1,l,r); 117 return ans; 118 }else if(l<=mid && r<=mid){ 119 ans=enqurytree(p<<1,l,r); 120 return ans; 121 } 122 } 123 124 int enqurytree1(int p,int x){//用来单点查询某个节点的颜色。。其实也可以用上面那个函数 125 if(t[p].lson==x && t[p].rson==x){return t[p].lco;} 126 spread(p); 127 int mid=(t[p].lson+t[p].rson)>>1; 128 if(x<=mid) return enqurytree1(p<<1,x); 129 if(x>mid) return enqurytree1(p<<1|1,x); 130 } 131 132 int enqury(int a,int b){ 133 int ans=0; 134 while(top[a]!=top[b]){ 135 if(d[top[a]]<d[top[b]]) swap(a,b); 136 ans+=enqurytree(1,id[top[a]],id[a]); 137 //查询“连接处”的颜色,如果颜色相同的话那么颜色块数量要建1 138 if(enqurytree1(1,id[top[a]])==enqurytree1(1,id[fa[top[a]]])) ans--; 139 a=fa[top[a]]; 140 } 141 if(d[a]>d[b]) swap(a,b); 142 //前面要减的已经减过了所以这里不用减 143 ans+=enqurytree(1,id[a],id[b]); 144 return ans; 145 } 146 147 int main(){ 148 n=read(),m=read(); 149 for(int i=1;i<=n;++i) c[i]=read(); 150 for(int i=1,x,y;i<n;++i){ 151 x=read(),y=read(); 152 add(x,y),add(y,x); 153 } 154 //处理出树链 155 dfs1(1); 156 dfs2(1,1); 157 //梦开始的地方 158 build(1,1,n); 159 while(m--){ 160 scanf("%s",&or1); 161 if(or1=='C'){ 162 int a,b,c; 163 a=read(),b=read(),c=read(); 164 change(a,b,c); 165 } 166 if(or1=='Q'){ 167 int a,b; 168 a=read(),b=read(); 169 printf("%d ",enqury(a,b)); 170 } 171 } 172 return 0; 173 }