题目描述 Description
|
给定一棵有n个节点的无根树和m个操作,操作有2类: 1、将节点a到节点b路径上所有点都染成颜色c; 2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。 请你写一个程序依次完成这m个操作。 |
输入描述 Input Description
|
第一行包含2个整数n和m,分别表示节点数和操作数; 第二行包含n个正整数表示n个节点的初始颜色 下面n-1行每行包含两个整数 和 ,表示x和y之间有一条无向边。 下面m行每行描述一个操作: “C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c; “Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。 |
输出描述 Output Description
|
对于每个询问操作,输出一行答案。
|
样例输入 Sample Input
|
6 5 2 2 1 2 1 1 1 2 1 3 2 4 2 5 2 6 Q 3 5 C 2 1 1 Q 3 5 C 5 1 2 Q 3 5 |
样例输出 Sample Output
|
3 1 2 |
数据范围及提示 Data Size & Hint
|
对于100%的数据1≤n≤10^5,1≤m≤10^5,1≤c≤10^9;
|
很明显是一道树链剖分的题,所以我们只需要考虑线性时候该怎么搞了。
用线段树是肯定的。对于每一个节点o,维护该区间内颜色块数量是肯定的,用cnum[o]来表示。但是发现不能强行合并,因为相邻部分有可能颜色相同,算成一块。所以我们再需要维护两个信息L[o]与R[o]分别记录区间o最左端的颜色以及最右边颜色就可以了。在pushup过程中需要判断如果r[lo]==l[ro],cnum[o]--;这样的话线性做法就可以了。
写好线段树之后,默一份树链剖分代码上去,套一个线段树,结果发现连样例都过不了!手画一下样例,发现在树链上进行操作时进行了若干次query,这些query中的相邻部分有可能颜色相同,怎么办呢,不要慌,我们再写一个名叫special的函数,专门来计算每一条重链上最上段点的颜色,在合并的时候特判一下就好了。
#include<iostream> #include<algorithm> #include<cstdio> #include<queue> #include<cmath> #include<cstring> using namespace std; typedef long long LL; #define mem(a,b) memset(a,b,sizeof(a)) inline int read() { int x=0,f=1;char c=getchar(); while(!isdigit(c)){if(c=='-')f=-1;c=getchar();} while(isdigit(c)){x=x*10+c-'0';c=getchar();} return x*f; } const int maxn=100010; int n,m,ce=-1,es,a,b,c,first[maxn],size[maxn],fa[maxn],deep[maxn],w[maxn],id[maxn],bl[maxn]; char tp; int cnum[maxn<<2],L[maxn<<2],R[maxn<<2],tag[maxn<<2],col[maxn]; struct Edge { int u,v,next; Edge() {} Edge(int _1,int _2,int _3) : u(_1),v(_2),next(_3) {} }e[maxn<<1]; void addEdge(int a,int b) { e[++ce]=Edge(a,b,first[a]);first[a]=ce; e[++ce]=Edge(b,a,first[b]);first[b]=ce; } void dfs(int now,int pa) { size[now]=1; for(int i=first[now];i!=-1;i=e[i].next) if(e[i].v!=pa) { fa[e[i].v]=now;deep[e[i].v]=deep[now]+1; dfs(e[i].v,now); size[now]+=size[e[i].v]; } } void divide(int now,int chain) { id[now]=++es;bl[now]=chain; int maxs=0; for(int i=first[now];i!=-1;i=e[i].next) if(fa[now]!=e[i].v && size[e[i].v]>size[maxs])maxs=e[i].v; if(!maxs)return; divide(maxs,chain); for(int i=first[now];i!=-1;i=e[i].next) if(fa[now]!=e[i].v && e[i].v!=maxs)divide(e[i].v,e[i].v); } void pushdown(int l,int r,int o) { if(tag[o]==-1 || l==r)return; int mid=(l+r)>>1,lo=o<<1,ro=lo|1; tag[lo]=tag[ro]=tag[o]; cnum[lo]=cnum[ro]=1; L[lo]=R[lo]=L[ro]=R[ro]=L[o]; tag[o]=-1; } void pushup(int l,int r,int o) { int mid=(l+r)>>1,lo=o<<1,ro=lo|1; L[o]=L[lo];R[o]=R[ro]; cnum[o]=cnum[lo]+cnum[ro]; if(R[lo]==L[ro])cnum[o]--; } void build(int l,int r,int o) { if(l==r) { cnum[o]=1;tag[o]=-1; L[o]=R[o]=col[l]; return; } int mid=(l+r)>>1,lo=o<<1,ro=lo|1; build(l,mid,lo);build(mid+1,r,ro); pushup(l,r,o);tag[o]=-1; } void update(int l,int r,int o,int a,int b,int c) { if(l==a && r==b) { tag[o]=c;cnum[o]=1; L[o]=R[o]=c; return; } pushdown(l,r,o); int mid=(l+r)>>1,lo=o<<1,ro=lo|1; if(b<=mid)update(l,mid,lo,a,b,c); else if(a>mid)update(mid+1,r,ro,a,b,c); else { update(l,mid,lo,a,mid,c); update(mid+1,r,ro,mid+1,b,c); } pushup(l,r,o); } int query(int l,int r,int o,int a,int b) { if(l==a && r==b)return cnum[o]; int mid=(l+r)>>1,lo=o<<1,ro=lo|1; pushdown(l,r,o); if(a>mid)return query(mid+1,r,ro,a,b); else if(b<=mid)return query(l,mid,lo,a,b); else { int ans=query(l,mid,lo,a,mid)+query(mid+1,r,ro,mid+1,b); return R[lo]==L[ro] ? ans-1 : ans; } } int special(int l,int r,int o,int a) { if(l==r)return L[o]; pushdown(l,r,o); int mid=(l+r)>>1,lo=o<<1,ro=lo|1; if(a<=mid)return special(l,mid,lo,a); else return special(mid+1,r,ro,a); } void paint(int a,int b,int c) { while(bl[a]!=bl[b]) { if(deep[bl[a]]<deep[bl[b]])swap(a,b); update(1,n,1,id[bl[a]],id[a],c); a=fa[bl[a]]; } if(id[a]>id[b])swap(a,b); update(1,n,1,id[a],id[b],c); } int Query(int a,int b) { int sum=0; while(bl[a]!=bl[b]) { if(deep[bl[a]]<deep[bl[b]])swap(a,b); sum+=query(1,n,1,id[bl[a]],id[a]); if(special(1,n,1,id[bl[a]])==special(1,n,1,id[fa[bl[a]]]))sum--; a=fa[bl[a]]; } if(id[a]>id[b])swap(a,b); sum+=query(1,n,1,id[a],id[b]); return sum; } int main() { mem(first,-1); n=read();m=read(); for(int i=1;i<=n;i++)w[i]=read(); for(int i=1;i<n;i++)a=read(),b=read(),addEdge(a,b); dfs(1,1);divide(1,1); for(int i=1;i<=n;i++)col[id[i]]=w[i]; build(1,n,1); while(m--) { cin>>tp;a=read();b=read(); if(tp=='C')c=read(),paint(a,b,c); else printf("%d ",Query(a,b)); } return 0; }