2243: [SDOI2011]染色
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),
如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Input
第一行包含2个整数n和m,分别表示节点数和操作数;
第二行包含n个正整数表示n个节点的初始颜色
下面 行每行包含两个整数x和y,表示x和y之间有一条无向边。
下面 行每行描述一个操作:
“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括a和b)都染成颜色c;
“Q a b”表示这是一个询问操作,询问节点a到节点b(包括a和b)路径上的颜色段数量。
Output
对于每个询问操作,输出一行答案。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
思路:
显然,这道题我们可以通过在整颗树上架一颗线段树、然后进行区间修改区间查询的操作。而询问a到b的路径,显然是a到lca(a,b)的路径和 b到lca(a,b)的路径的并。所以我们可以使用先树链剖分再在剖分序列上架线段树的方式维护这样的连续区间。通过在dfs2的时候给每个点压入队列确定顺序,再在LCA的时候对每段进行查询,即可得到每次的答案;
至于线段树的操作,本题显然是一道区间合并的线段树,我们可以维护如下三种信息:区间左端点颜色,区间右端点颜色、区间颜色段数目。在pushup时,若左子树的右端点和右子树的左端点颜色相同,则是同一段颜色,sum[p] = sum[ls]+sum[rs] - 1,否则是两段颜色,不需要-1 。值得注意的是,本题的颜色范围是[0,1e9],所以在打lazy_tag时,要注意laz的初始值,不能是0,。但是我选择把所有输入的颜色++,相当于区间平移,不会对答案产生影响。
另外,本题函数非常多,可以用结构体封装以增加代码清晰度,我选择了用class封装,但是可能会稍微慢一点。
代码如下
#include <cstring> #include <cstdio> #include <algorithm> #include <iostream> #include <cctype> #define ls p<<1 #define rs ls|1 #define lson l,mid,ls #define rson mid+1,r,rs #define im int mid = (l + r) >> 1 using namespace std; const int N = 1100000; int son[N],siz[N],fa[N],top[N],dep[N]; int idx[N],idy[N],cnt2; int to[N<<1],next[N<<1],pval[N],head[N],cnt1; int n,m; class ReadIn { private: inline char nc() { static char buf[100000], *p1, *p2; return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++; } public: inline int read() { int x=0;char ch=nc(); while(!isdigit(ch))ch=nc(); while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=nc();} return x; } inline char getc() { char ch=nc(); while(isspace(ch))ch=nc(); return ch; } }Rd; class SegmentTree { private: int lft[N<<2],rft[N<<2],sum[N<<2],laz[N<<2]; void pushdown(int p) { if(!laz[p]) return ; else { sum[ls]=sum[rs]=1; lft[ls]=laz[ls]=rft[ls]= lft[rs]=laz[rs]=rft[rs]=laz[p]; } laz[p]=0; return; } void pushup(int p) { sum[p]=sum[ls]+sum[rs]; if(rft[ls]==lft[rs])sum[p]--; lft[p]=lft[ls]; rft[p]=rft[rs]; } public: void build(int l,int r,int p) { if(l==r) { lft[p]=rft[p]=pval[idy[l]]; sum[p]=1; return ; } im; build(lson); build(rson); pushup(p); return ; } int query(int l,int r,int p,int x,int y) { pushdown(p); if(x<=l&&y>=r) { return sum[p]; } im; if(y<=mid) return query(lson,x,y); else if(x>mid) return query(rson,x,y); else { int re = query(lson,x,y)+query(rson,x,y);; if(rft[ls]==lft[rs])re--; return re; } } void change(int l,int r,int p,int x,int y,int c) { if(x<=l&&y>=r) { laz[p]=c; sum[p]=1; lft[p]=rft[p]=laz[p]; return; } im; pushdown(p); if(x<=mid) change(lson,x,y,c); if(y>mid) change(rson,x,y,c); pushup(p); } int find(int l,int r,int p,int x) { pushdown(p); if(l==r) return lft[p]; im; if(x<=mid) return find(lson,x); else return find(rson,x); } }Tr; class TreeChainDissection { public: void dfs1(int p) { dep[p]=dep[fa[p]]+1; siz[p]=1; for (int i = head[p];i; i = next[i] ) { if(to[i] != fa[p]) { fa[to[i]]=p; dfs1(to[i]); siz[p]+=siz[to[i]]; if(siz[to[i]]>siz[son[p]]) son[p]=to[i]; } } } void dfs2(int p,int t) { idx[p]=++cnt2; idy[cnt2]=p; top[p]=t; if(son[p]) dfs2(son[p],t); for(int i=head[p];i;i=next[i]) if(to[i]!=fa[p]&&to[i]!=son[p]) dfs2(to[i],to[i]); } int lcaq(int a,int b) { int ans=0; while(top[a]!=top[b]) { if(dep[top[a]]>dep[top[b]])swap(a,b); int upc = Tr.find(1,n,1,idx[fa[top[b]]]); int doc = Tr.find(1,n,1,idx[top[b]]); ans+= Tr.query(1,n,1,idx[top[b]],idx[b]); if(upc==doc) ans--; //printf("%d ",ans); b=fa[top[b]]; } if(dep[a]<dep[b])swap(a,b); ans+=Tr.query(1,n,1,idx[b],idx[a]); if(!ans)ans=1; return ans; } void lcac(int x,int y,int z) { while(top[x]!=top[y]) { if(dep[top[x]]>dep[top[y]])swap(x,y); Tr.change(1,n,1,idx[top[y]],idx[y],z); y=fa[top[y]]; } if(dep[x]<dep[y])swap(x,y); Tr.change(1,n,1,idx[y],idx[x],z); } }Tcd; class Pre { private: inline void add_edge(int a,int b) { to[++cnt1] = b; next[cnt1] = head[a]; head[a] = cnt1; to[++cnt1] = a; next[cnt1] = head[b]; head[b] = cnt1; } public: void init() { n=Rd.read(),m=Rd.read(); int i,x,y; for(i=1 ;i <=n; i++) pval[i] = Rd.read() + 1; for(i=1 ;i < n; i++) { x=Rd.read(),y=Rd.read(); add_edge(x,y); } } }Pr; void solve() { Tcd.dfs1(1); Tcd.dfs2(1,1); Tr.build(1,n,1); int i,x,y,z; char opt; for(i=1;i<=m;i++) { opt=Rd.getc(); if(opt=='Q') { x=Rd.read(); y=Rd.read(); int ans=Tcd.lcaq(x,y); printf("%d ",ans); } else { x=Rd.read(), y=Rd.read(), z=Rd.read(); Tcd.lcac(x,y,z+1); } } } int main() { Pr.init(); solve(); }
欢迎来原博客看看 >原文链接<