一下午搭进去了。。
最后发现少写了个pushup???
有个关键的细节就是每次可合并两个区间的时候要判断一下ls的最左边的颜色和rs最右边的颜色一样不一样,一样就ans--
然后就update和query的时候把链分成两节弄,否则很麻烦。
#include <iostream> #include <cstdio> #include <cstring> using namespace std; const int N=100005; int n,dfn[N],top[N],siz[N],son[N],nxt[N<<1],to[N<<1],a[N],tim,dep[N],fa[N],head[N],ecnt,rnk[N]; struct Seg{ int l,r,lc,rc,cnt,lazy; Seg(){lazy=-1;} }t[N<<3]; inline void add(int bg,int ed) { nxt[++ecnt]=head[bg]; to[ecnt]=ed; head[bg]=ecnt; } void dfs(int x) { siz[x]=1; for(int i=head[x];i;i=nxt[i]) { if(to[i]!=fa[x]) { fa[to[i]]=x; dep[to[i]]=dep[x]+1; dfs(to[i]); siz[x]+=siz[to[i]]; if(siz[to[i]]>siz[son[x]]) son[x]=to[i]; } } } void dfs(int x,int Top) { top[x]=Top;dfn[x]=++tim;rnk[tim]=x; if(!son[x]) return; dfs(son[x],Top); for(int i=head[x];i;i=nxt[i]) { if(fa[x]!=to[i]&&to[i]!=son[x]) dfs(to[i],to[i]); } } #define ls (cur<<1) #define rs (cur<<1|1) #define mid (l+r>>1) inline void pushup(int cur) { t[cur].lc=t[ls].lc; t[cur].rc=t[rs].rc; t[cur].cnt=t[ls].cnt+t[rs].cnt-(t[ls].rc==t[rs].lc); } void build(int l,int r,int cur) { t[cur].l=l,t[cur].r=r;t[cur].lazy=-1; if(l==r) { t[cur].lc=t[cur].rc=a[rnk[l]];t[cur].cnt=1; //cerr<<"BUILD:"<<rnk[l]<<' '<<t[cur].lc<<endl; return; } build(l,mid,ls); build(mid+1,r,rs); pushup(cur); } inline void pushdown(int cur) { if(t[cur].lazy!=-1) { t[ls].cnt=t[cur].cnt=1;t[cur].lc=t[cur].rc=t[ls].lc=t[ls].rc=t[ls].lazy=t[cur].lazy; t[rs].cnt=1;t[rs].lc=t[rs].rc=t[rs].lazy=t[cur].lazy; t[cur].lazy=-1; } return; } void modify(int ql,int qr,int l,int r,int cur,int c) { if(ql>qr) return; pushdown(cur); if(ql<=l&&r<=qr) { t[cur].cnt=1;t[cur].lazy=t[cur].lc=t[cur].rc=c;return; } int md=(l+r)>>1; if(ql<=mid) modify(ql,qr,l,mid,ls,c); if(mid<qr) modify(ql,qr,mid+1,r,rs,c); pushup(cur); } int query(int ql,int qr,int l,int r,int cur) { if(ql>qr) return 0; pushdown(cur); if(ql<=l&&r<=qr) return t[cur].cnt; //int md=l+r>>1; if(qr<=mid) return query(ql,qr,l,mid,ls); else if(ql>mid) return query(ql,qr,mid+1,r,rs); else { int x=query(ql,qr,l,mid,ls),y=query(ql,qr,mid+1,r,rs); return x+y-(t[ls].rc==t[rs].lc); } } void _add(int x,int y,int c) { while(top[x]!=top[y]) { modify(dfn[top[x]],dfn[x],1,n,1,c); x=fa[top[x]]; } modify(dfn[y],dfn[x],1,n,1,c); } int col(int cur,int x) { pushdown(cur); if(t[cur].l==t[cur].r) return t[cur].lc; int md=(t[cur].l+t[cur].r)>>1; if(x<=md) return col(ls,x);else return col(rs,x); } int query(int x,int y) { int ans=0; while(top[x]!=top[y]) { ans+=query(dfn[top[x]],dfn[x],1,n,1); if(col(1,dfn[fa[top[x]]])==col(1,dfn[top[x]])) ans--; x=fa[top[x]]; } ans+=query(dfn[y],dfn[x],1,n,1); return ans; } int LCA(int x,int y) { while(top[x]!=top[y]) (dep[top[x]]>=dep[top[y]])? x=fa[top[x]]:y=fa[top[y]];; return dep[x]<dep[y]?x:y; } int m; int main() { scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&a[i]); for(int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x); dfs(1); dfs(1,1); build(1,n,1); char opt[5];int a,b,c; while(m--) { scanf("%s",opt); if(opt[0]=='C') { scanf("%d%d%d",&a,&b,&c); int lca=LCA(a,b); _add(a,lca,c),_add(b,lca,c); } else if(opt[0]=='Q') { scanf("%d%d",&a,&b); int lca=LCA(a,b); printf("%d ",query(a,lca)+query(b,lca)-1); } } return 0; }