分析
lxl大毒瘤。
感谢Ouuan等CNOIER提供了这么好的比赛。
这里只是把官方题解复述一遍,可以直接去看官方题解:点我。
考虑将问题转化为对于每个颜色,求出没有经过这个颜色的节点的路径有多少条,这问题的答案是:
[sum_{i=1}^{n}(n^2-sum_{G'}siz_{G'}^2)
]
其中(G')是所有不包含颜色为(i)的节点的极大连通块。
视颜色为(i)的节点为白点,其余为黑点,那么我们现在要做的就是维护一个数据结构,支持:
-
修改一个节点的颜色。
-
询问所有黑点的极大连通块的大小的平方和。
考虑使用LCT,如果每个黑点向父亲连边的话,那么真实的黑点的极大连通块就是每个连通块去掉根节点后得到的连通块,这样我们可以使用LCT维护子树信息的技巧维护。
注意(link(x))和(cut(x))函数的实现。(这里参考了标程)
时间复杂度为(O((n+m) log n))。
代码
#include <bits/stdc++.h>
#define rin(i,a,b) for(int i=(a);i<=(b);++i)
#define irin(i,a,b) for(int i=(a);i>=(b);--i)
#define trav(i,a) for(int i=head[a];i;i=e[i].nxt)
#define Size(a) (int)a.size()
#define pb push_back
#define mkpr std::make_pair
#define fi first
#define se second
#define lowbit(a) ((a)&(-(a)))
#define sqr(a) ((a)*(a))
typedef long long LL;
typedef std::pair<int,int> pii;
using std::cerr;
using std::endl;
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
const int MAXN=400005;
int n,m,ecnt,head[MAXN];
int c[MAXN],fa[MAXN];
LL ans[MAXN],cur;
std::vector<pii> vec[MAXN];
struct Edge{
int to,nxt;
}e[MAXN<<1];
inline void add_edge(int bg,int ed){
++ecnt;
e[ecnt].to=ed;
e[ecnt].nxt=head[bg];
head[bg]=ecnt;
}
struct lct{
int fa,ch[2];
int siz;
int isiz;
LL ich2;
bool tag;
}a[MAXN];
#define lc a[x].ch[0]
#define rc a[x].ch[1]
inline bool isroot(int x){
return a[a[x].fa].ch[0]!=x&&a[a[x].fa].ch[1]!=x;
}
inline void pushr(int x){
std::swap(lc,rc);
a[x].tag^=1;
}
inline void pushdown(int x){
if(!a[x].tag)return;
if(lc)pushr(lc);
if(rc)pushr(rc);
a[x].tag=false;
}
inline void pushup(int x){
a[x].siz=a[lc].siz+a[rc].siz+1+a[x].isiz;
}
inline void rotate(int x){
int y=a[x].fa,z=a[y].fa,d=a[y].ch[1]==x,w=a[x].ch[d^1];
if(!isroot(y))a[z].ch[a[z].ch[1]==y]=x;
a[x].ch[d^1]=y;
a[y].ch[d]=w;
if(w)a[w].fa=y;
a[y].fa=x;
a[x].fa=z;
pushup(y);
}
int top,sta[MAXN];
inline void splay(int x){
int y=x,z=0;
sta[top=1]=y;
while(!isroot(y))sta[++top]=y=a[y].fa;
while(top)pushdown(sta[top--]);
while(!isroot(x)){
y=a[x].fa,z=a[y].fa;
if(!isroot(y)){
if((a[z].ch[0]==y)==(a[y].ch[0]==x))rotate(y);
else rotate(x);
}
rotate(x);
}
pushup(x);
}
inline void access(int x){
for(int y=0;x;x=a[y=x].fa){
splay(x);
a[x].isiz+=a[rc].siz;
a[x].ich2+=sqr(1ll*a[rc].siz);
rc=y;
a[x].isiz-=a[rc].siz;
a[x].ich2-=sqr(1ll*a[rc].siz);
pushup(x);
}
}
inline void makeroot(int x){
access(x);
splay(x);
pushr(x);
}
inline int findroot(int x){
access(x);
splay(x);
while(pushdown(x),lc)x=lc;
splay(x);
return x;
}
inline void split(int x,int y){
makeroot(x);
access(y);
splay(y);
}
inline void link(int x,int y){
split(x,y);
a[x].fa=y;
a[y].isiz+=a[x].siz;
a[y].ich2+=sqr(1ll*a[x].siz);
pushup(y);
}
inline void cut(int x,int y){
split(x,y);
a[y].ch[0]=0;
a[x].fa=0;
pushup(y);
}
inline void link(int x){
int y=fa[x];
access(x);
splay(x);
cur-=a[x].ich2+sqr(1ll*a[rc].siz);
int z=findroot(y);
access(y);
splay(z);
cur-=sqr(1ll*a[a[z].ch[1]].siz);
splay(y);
a[x].fa=y;
a[y].isiz+=a[x].siz;
a[y].ich2+=sqr(1ll*a[x].siz);
pushup(y);
access(x);
splay(z);
cur+=sqr(1ll*a[a[z].ch[1]].siz);
}
inline void cut(int x){
int y=fa[x];
access(x);
cur+=a[x].ich2;
int z=findroot(y);
access(x);
splay(z);
cur-=sqr(1ll*a[a[z].ch[1]].siz);
splay(x);
a[lc].fa=0;
lc=0;
pushup(x);
access(y);
splay(z);
cur+=sqr(1ll*a[a[z].ch[1]].siz);
}
#undef lc
#undef rc
void dfs(int x,int pre){
fa[x]=pre;
trav(i,x){
int ver=e[i].to;
if(ver==pre)continue;
dfs(ver,x);
}
a[x].fa=pre;
a[pre].isiz+=a[x].siz;
a[pre].ich2+=sqr(1ll*a[x].siz);
pushup(pre);
}
int main(){
n=read(),m=read();
rin(i,1,n){
c[i]=read();
vec[c[i]].pb(mkpr(i,0));
}
rin(i,2,n){
int u=read(),v=read();
add_edge(u,v);
add_edge(v,u);
}
rin(i,1,n+1)a[i].siz=1;
dfs(1,n+1);
rin(i,1,m){
int x=read(),y=read();
if(c[x]==y)continue;
vec[c[x]].pb(mkpr(-x,i));
c[x]=y;
vec[c[x]].pb(mkpr(x,i));
}
cur=sqr(1ll*n);
rin(i,1,n){
rin(j,0,Size(vec[i])-1){
int x=vec[i][j].fi;
if(x>0)cut(x);
else link(-x);
ans[vec[i][j].se]+=sqr(1ll*n)-cur;
ans[j==Size(vec[i])-1?m+1:vec[i][j+1].se]-=sqr(1ll*n)-cur;
}
irin(j,Size(vec[i])-1,0){
int x=vec[i][j].fi;
if(x>0)link(x);
else cut(-x);
}
}
rin(i,1,m)ans[i]+=ans[i-1];
rin(i,0,m)printf("%lld
",ans[i]);
return 0;
}