这是一道很裸的树链剖分题：剖分后用线段树维护每个区间的左端颜色、右端颜色和颜色段数，合并两个区间时，若左区间右端颜色与右区间左端颜色相同，则段数减一。直接上代码。
PS：没调几次就过了，真开心~~~
CODE:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
#define maxn 100010
#define maxm 200010

// Undirected tree stored as linked edge lists. The k-th undirected edge
// occupies slots 2k (from -> to) and 2k+1 (to -> from); slot 0 is never used,
// so an index of 0 terminates a list.
struct edges{
    int to,next;   // target vertex; index of the next edge leaving the same vertex (0 = end)
}edge[maxm];
// head[v]: first edge leaving v (0 = none). This must not be called `next`:
// with `using namespace std;` a global `next` is ambiguous with std::next
// (C++11 and later) and the file fails to compile.
int head[maxn],ednum;

// Add the undirected edge {from,to}. Always returns 0.
int addedge(int from,int to){
    ednum++;
    int fwd=ednum*2,bwd=ednum*2+1;
    // Plain member assignment instead of the non-standard (edges){...} compound literal.
    edge[fwd].to=to;
    edge[fwd].next=head[from];
    edge[bwd].to=from;
    edge[bwd].next=head[to];
    head[from]=fwd;
    head[to]=bwd;
    return 0;
}

bool b[maxn];    // visited flags, shared by dfs() and heavy(); must be cleared before each pass
int dep[maxn];   // depth: dep[fa[u]]+1, with dep[0]==0 so a vertex whose fa is 0 has depth 1
int sum[maxn];   // subtree size
int pos[maxn];   // pos[k]: vertex placed at position k (1-based) by heavy()
int color[maxn]; // per-vertex color, read by the segment tree via color[pos[k]]
int add[maxn];   // add[v]: position of v; inverse of pos
int fa[maxn];    // parent (0 for the start vertex)
int pre[maxn];   // top vertex of the heavy chain containing v
int ch[maxn];    // heavy child (largest subtree), 0 if none
int cnt;         // number of positions handed out so far by heavy()

// First pass: set fa of each child, dep, sum and the heavy child ch.
// Expects fa[u] to be set for the start vertex and b[] cleared. Recursive.
int dfs(int u){
    dep[u]=dep[fa[u]]+1;
    b[u]=1;sum[u]=1;
    for (int i=head[u];i;i=edge[i].next){
        int v=edge[i].to;
        if (b[v]) continue;              // already visited (the parent, in a tree)
        fa[v]=u;
        dfs(v);
        sum[u]+=sum[v];
        if (sum[ch[u]]<sum[v]) ch[u]=v;  // sum[0]==0, so any child beats "no heavy child"
    }
    return 0;
}

// Second pass: assign consecutive positions so every heavy chain is a
// contiguous range, and record each vertex's chain top in pre.
// bo==true means u starts a new chain. Expects b[] cleared. Recursive.
int heavy(int u,bool bo){
    b[u]=1;
    pre[u]=bo?u:pre[fa[u]];
    // Three separate statements: the old one-liner `add[pos[++cnt]=u]=cnt`
    // was UB before C++17 and stores cnt-1 under C++17's right-to-left
    // sequencing of assignment.
    ++cnt;
    pos[cnt]=u;
    add[u]=cnt;
    if (ch[u]) heavy(ch[u],false);       // continue the current chain first
    for (int i=head[u];i;i=edge[i].next)
        if (!b[edge[i].to]) heavy(edge[i].to,true);  // light children start new chains
    return 0;
}
// Summary of a run of colors over consecutive positions:
//   l   = color at the left (lower-position) end,
//   r   = color at the right end,
//   sum = number of maximal same-color segments.
// The empty summary used by callers is {-1,-1,0}; sum==0 is what marks it empty.
struct bake{
    int l,r,sum;
};
// Concatenate x (left) and y (right). If the touching end colors are equal,
// the two boundary segments merge into one. Either operand may be empty.
// Emptiness is tested via sum, so empty+empty stays empty (sum 0, not -1) and
// a real color of -1 is never mistaken for the empty sentinel.
bake operator +(const bake &x,const bake &y){
    if (x.sum==0) return y;
    if (y.sum==0) return x;
    bake ans={x.l,y.r,x.sum+y.sum};
    if (x.r==y.l) ans.sum--;
    return ans;
}
// Segment-tree node covering HLD positions [l, r].
struct node{
    int l;      // first position covered
    int r;      // last position covered
    bake b;     // color summary of the covered positions
    bool flag;  // set by set(): the children still owe this node's single color
}t[maxn*8];     // node pool; node x has children 2x and 2x+1
// Make node x one solid run of color col: a single segment whose two ends
// both have that color. The flag records that this color still has to be
// pushed to the children. Always returns 0.
int set(int x,int col){
    t[x].flag=1;
    t[x].b.sum=1;
    t[x].b.r=col;
    t[x].b.l=col;
    return 0;
}
// Build the subtree rooted at node x over positions [l,r]. The leaf for
// position p takes the color of vertex pos[p]; internal nodes combine their
// children's summaries with operator+. Always returns 0.
int build(int x,int l,int r){
    t[x].l=l;
    t[x].r=r;
    t[x].flag=0;
    if (l!=r){
        int mid=(l+r)>>1;
        int lc=x<<1;
        int rc=lc+1;
        build(lc,l,mid);
        build(rc,mid+1,r);
        t[x].b=t[lc].b+t[rc].b;
    }else{
        set(x,color[pos[l]]);
    }
    return 0;
}
// Lazy propagation ("push down"): if node x carries a pending single-color
// assignment, hand it to both children. Because set() makes b.l == b.r, b.l
// is the assigned color. A leaf has no children, so only its flag is cleared.
// Previously the leaf case wrote to slots 2x and 2x+1, outside the built tree.
// Always returns 0.
int pushback(int x){
    if (!t[x].flag) return 0;
    t[x].flag=0;
    if (t[x].l==t[x].r) return 0;   // leaf: nothing below to update
    set(x<<1,t[x].b.l);
    set((x<<1)+1,t[x].b.l);
    return 0;
}
// Assign color col to positions [x1,y1] inside node x's range, then refresh
// x's summary. Always returns 0. The original fell off the end of this int
// function on the partial-overlap path, which is undefined behavior.
int cha(int x,int x1,int y1,int col){
    int l=t[x].l,r=t[x].r;
    if (y1<l||x1>r) return 0;                 // disjoint: nothing to do
    if (l>=x1&&r<=y1) {set(x,col);return 0;}  // fully covered: paint lazily
    // Partial overlap (never a leaf): children must be up to date first.
    pushback(x);
    cha(x<<1,x1,y1,col);
    cha((x<<1)+1,x1,y1,col);
    t[x].b=t[x<<1].b+t[(x<<1)+1].b;
    return 0;
}
// Recolor every vertex on the tree path between l and r with color co.
// While the endpoints are on different heavy chains, recolor the chain
// segment of the endpoint whose chain top is deeper and jump above that
// chain. Once both share a chain, recolor the segment between them.
// Always returns 0.
int change(int l,int r,int co){
    while (pre[l]!=pre[r]){
        if (dep[pre[l]]<dep[pre[r]]) swap(l,r);
        cha(1,add[pre[l]],add[l],co);
        l=fa[pre[l]];
    }
    if (dep[l]>dep[r]) swap(l,r);
    cha(1,add[l],add[r],co);
    return 0;
}
// Color summary of positions [x1,y1] intersected with node x's range.
// Returns the empty summary {-1,-1,0} when they do not intersect.
bake get(int x,int x1,int y1){
    pushback(x);
    int lo=t[x].l;
    int hi=t[x].r;
    if (y1<lo||x1>hi){
        bake none={-1,-1,0};
        return none;
    }
    if (x1<=lo&&hi<=y1) return t[x].b;
    bake left=get(x<<1,x1,y1);
    bake right=get((x<<1)+1,x1,y1);
    return left+right;
}
int query(int l,int r){
bake la=(bake){-1,-1,0},ra=(bake){-1,-1,0};
for (;;){
if (pre[l]==pre[r]){
if (dep[l]>dep[r]) {swap(l,r) ;swap(la,ra);}
ra=get(1,add[l],add[r])+ra;
int ans=ra.sum+la.sum;
if (ra.l==la.l) ans--;
return ans;
}else {
if (dep[pre[l]]<dep[pre[r]]) {swap(la,ra);swap(l,r);}
la=get(1,add[pre[l]],add[l])+la;
l=fa[pre[l]];
}
}
}
int main(){
int n,m;
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++) scanf("%d",color+i);
for (int i=1;i<n;i++) {
int x,y;
scanf("%d%d",&x,&y);
addedge(x,y);
}
memset(b,0,sizeof(b));
dfs(1);
memset(b,0,sizeof(b));
heavy(1,true);
build(1,1,n);
for (int i=1;i<=m;i++) {
char s[2];int x,y,z;
scanf("%s",s);
if (s[0]=='C') {
scanf("%d%d%d",&x,&y,&z);
change(x,y,z);
}else{
scanf("%d%d",&x,&y);
printf("%d
",query(x,y));
}
}
return 0;
}