题目描述
这棵树上有n个节点,由n−1条树枝相连。初始时树上都挂了一个毒瘤,颜色为ci。接下来Salamander将会进行q个操作。
Salamander有时会修改树上某个点到另外一个点的简单路径上所有毒瘤的颜色。
对于给定的树上某个点集S,Salamander还定义了某个点的权值:
(W_i=sum_{jin S}T(i,j))
其中T(i,j)表示i到j的路径上毒瘤颜色的段数,比如i到j的路径上毒瘤颜色为1223312时,颜色段数为5。
Salamander对树上的毒瘤们的状态很感兴趣,所以有时会指定树上m个节点作为点集,询问这m个节点的权值。
输入输出格式
输入格式:
第一行包括两个正整数n、q,表示树上的节点数和操作个数。
第二行包括用空格隔开的n个正整数ci,表示树上每个节点初始的毒瘤颜色。
接下来n−1行,每行两个正整数u、v,表示树上有一条连接u和v的边。
接下来描述q个操作:
- 1 若给出的第一个整数等于1,那么接下来将会有三个正整数u、v、y,表示将树上编号为u的点到编号为v的点的简单路径上的毒瘤颜色全都改为y。
- 2 若给出的第一个整数等于2,那么接下来将会有一个正整数m,表示指定的树上节点个数。下一行将会有m个用空格隔开的互不相同的正整数,表示当前询问给定的m个节点。
输出格式:
若干行,对于每个2操作输出对应的答案。
Solution
树剖维护修改。
对于询问,很容易想到虚树。
对于虚树上的边,边权定义为去掉(dep)小的那个点所在的颜色块剩余的块的数量。
这样的话,一条路径的权值就是边权之和(+1)。
所以,建完虚树之后,直接二次换根(dp)就好了。
(当然你也可以点分治,但是那样代码大概会破300行。。)
所以这玩意就是一模板大集合。。
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('
');}
const int maxn = 2e5+10;
int n,col[maxn],m,sz[maxn],dfn[maxn],f[maxn],dep[maxn],re[maxn];
#define ls p<<1
#define rs p<<1|1
#define mid ((l+r)>>1)
struct data {
int lcol,rcol,sum;
data operator + (const data &rhs) const {
if(!sum) return rhs;
if(!rhs.sum) return *this;
return (data){lcol,rhs.rcol,sum+rhs.sum-(rcol==rhs.lcol)};
}
};
struct Segment_Tree {
int cov[maxn<<2];
data t[maxn<<2];
void push_tag(int p,int x) {t[p].sum=1,t[p].lcol=t[p].rcol=x,cov[p]=x;}
void pushdown(int p) {
if(!cov[p]) return ;
push_tag(ls,cov[p]),push_tag(rs,cov[p]);
cov[p]=0;
}
void modify(int p,int l,int r,int x,int y,int z) {
if(x<=l&&r<=y) return push_tag(p,z),void();
pushdown(p);
if(x<=mid) modify(ls,l,mid,x,y,z);
if(y>mid) modify(rs,mid+1,r,x,y,z);
t[p]=t[ls]+t[rs];
}
data query(int p,int l,int r,int x,int y) {
if(x<=l&&r<=y) return t[p];
pushdown(p);
if(y<=mid) return query(ls,l,mid,x,y);
else if(x>mid) return query(rs,mid+1,r,x,y);
else return query(ls,l,mid,x,y)+query(rs,mid+1,r,x,y);
}
void build(int p,int l,int r) {
if(l==r) return t[p].sum=1,t[p].lcol=t[p].rcol=col[re[l]],void();
build(ls,l,mid),build(rs,mid+1,r),t[p]=t[ls]+t[rs];
}
}SGT;
struct Heavy_Light_Decomposation {
int head[maxn],tot,hs[maxn],top[maxn],dfn_cnt;
struct edge{int to,nxt;}e[maxn<<1];
void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}
void dfs(int x,int fa) {
sz[x]=1,f[x]=fa,dep[x]=dep[fa]+1;
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa) {
dfs(e[i].to,x);sz[x]+=sz[e[i].to];
if(sz[e[i].to]>sz[hs[x]]) hs[x]=e[i].to;
}
}
void dfs2(int x) {
if(hs[f[x]]==x) top[x]=top[f[x]];
else top[x]=x;
dfn[x]=++dfn_cnt;re[dfn_cnt]=x;
if(hs[x]) dfs2(hs[x]);
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=f[x]&&e[i].to!=hs[x]) dfs2(e[i].to);
}
int lca(int x,int y) {
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=f[top[x]];
}if(dep[x]>dep[y]) swap(x,y);return x;
}
void modify(int x,int y,int z) {//puts("OK");
while(top[x]!=top[y]) {
if(dep[top[x]]<dep[top[y]]) swap(x,y);
SGT.modify(1,1,n,dfn[top[x]],dfn[x],z);
x=f[top[x]];
}if(dep[x]>dep[y]) swap(x,y);
SGT.modify(1,1,n,dfn[x],dfn[y],z);
}
void print(data x) {printf("Print :: %d %d %d
",x.sum,x.lcol,x.rcol);}
int query(int x,int t) {
data ans;int bo=1;
while(top[x]!=top[t]) {
data res=SGT.query(1,1,n,dfn[top[x]],dfn[x]);
if(bo) bo=0,ans=res;else ans=res+ans;
x=f[top[x]];
}
data res=SGT.query(1,1,n,dfn[t],dfn[x]);
if(bo) bo=0,ans=res;else ans=res+ans;
return ans.sum;
}
void debug(int x,int fa) {
//printf("Debug :: %d %d
",x,SGT.query(1,1,n,dfn[x],dfn[x]).lcol);
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa) debug(e[i].to,x);
}
}HLD;
int cmp(int x,int y) {return dfn[x]<dfn[y];}
int id[maxn];
int cmp2(int x,int y) {return id[x]<id[y];}
struct Virtual_Tree {
int head[maxn],tot,c[maxn],f[maxn],g[maxn],vis[maxn];
struct edge{int to,nxt,w;}e[maxn<<1];
void add(int u,int v,int w) {e[++tot]=(edge){v,head[u],w},head[u]=tot;}
void ins(int u,int v,int w) {add(u,v,w),add(v,u,w);}
int in[maxn],k,use[maxn],used,sta[maxn],top;
void build() {
sta[++top]=1,use[++used]=1;
for(int i=1;i<=k;i++) {
if(in[i]==1) continue;
int t=HLD.lca(in[i],sta[top]),pre=-1;
while(dfn[sta[top]]>dfn[t]&&dfn[sta[top]]<dfn[t]+sz[t]) {
if(pre!=-1) ins(sta[top],pre,HLD.query(pre,sta[top])-1);
pre=sta[top],use[++used]=sta[top],top--;
}
if(pre!=-1) ins(t,pre,HLD.query(pre,t)-1);
if(sta[top]!=t) sta[++top]=t;
sta[++top]=in[i];
}
int pre=-1;
while(top) {
if(pre!=-1) ins(sta[top],pre,HLD.query(pre,sta[top])-1);
pre=sta[top],use[++used]=sta[top],top--;
}
}
void dfs(int x,int fa) {
if(vis[x]) c[x]=1;else c[x]=0;
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa) {
//printf("Dfs :: %d %d %d
",x,e[i].to,e[i].w);
dfs(e[i].to,x);c[x]+=c[e[i].to];
f[x]+=e[i].w*c[e[i].to]+f[e[i].to];
}
}
void clear() {
for(int i=1;i<=used;i++) {
int t=use[i];
head[t]=id[t]=vis[t]=f[t]=g[t]=c[t]=0;
}tot=used=top=0;
}
void dfs2(int x,int fa) {
for(int i=head[x];i;i=e[i].nxt)
if(e[i].to!=fa) {
int v=e[i].to;
g[v]=g[x]-e[i].w*c[v]+e[i].w*(k-c[v]);
dfs2(e[i].to,x);
}
}
void solve() {
read(k);for(int i=1;i<=k;i++) read(in[i]),vis[in[i]]=1,id[in[i]]=i;
sort(in+1,in+k+1,cmp);build();
dfs(1,0);g[1]=f[1];dfs2(1,0);
sort(in+1,in+k+1,cmp2);
for(int i=1;i<=k;i++) printf("%d ",g[in[i]]+k);puts("");
clear();
//printf("
-------
");
}
}VT;
int main() {
read(n),read(m);
for(int i=1;i<=n;i++) read(col[i]);
for(int i=1,x,y;i<n;i++) read(x),read(y),HLD.ins(x,y);
HLD.dfs(1,0),HLD.dfs2(1),SGT.build(1,1,n);
for(int i=1;i<=m;i++) {
int op,u,v,y;read(op);
if(op==1) read(u),read(v),read(y),HLD.modify(u,v,y);
else VT.solve();
}
//write(HLD.query(7,1));
//HLD.debug(1,0);
return 0;
}