自己独立想出来的,超级开心
一开始想的是对于每一个点分别算这个点对答案的贡献.
但是呢,我们发现由于每一条路径的贡献是该路径颜色种类数,而每个颜色可能出现多次,所以这样就特别不好算贡献.
那么,还是上面那句话,由于算的是颜色种类,所以我们可以对每一个颜色种类单独算贡献.
即不以点为单位去算,而是以颜色种类为单位去算.
假设要算颜色为 $r$ 的贡献,那么必须保证每一个路径至少有一个端点在颜色 $r$ 构成的连通块中.
这句话等同于不能出现两个端点都在非 $r$ 连通块的路径,即 $n^2-sum_{col[i] eq r}size[i]^2$
对于每一个颜色都这么算就好了 ~
具体的话需要离线+撤销+LCT维护子树信息(就是那个平方和)
然后还要用到那个点权转边权,每次只删除和父亲连边的那个套路 ~
code:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 600003 #define LL long long #define lson t[x].ch[0] #define rson t[x].ch[1] #define setIO(s) freopen(s".in","r",stdin) ,freopen(s".out","w",stdout) using namespace std; LL ans,re[N]; int edges; int fa[N],hd[N],to[N<<1],nex[N<<1],val[N],size[N],col[N],is[N]; void add(int u,int v) { nex[++edges]=hd[u],hd[u]=edges,to[edges]=v; } struct data { int u,v,tim; data(int u=0,int v=0,int tim=0):u(u),v(v),tim(tim){} }; vector<data>G[N]; struct node { LL sqr; int ch[2],rev,f,siz,son; }t[N]; int get(int x) { return t[t[x].f].ch[1]==x; } int isrt(int x) { return !(t[t[x].f].ch[0]==x||t[t[x].f].ch[1]==x); } void pushup(int x) { t[x].siz=t[lson].siz+t[rson].siz+t[x].son+1; } void rotate(int x) { int old=t[x].f,fold=t[old].f,which=get(x); if(!isrt(old)) t[fold].ch[t[fold].ch[1]==old]=x; t[old].ch[which]=t[x].ch[which^1],t[t[old].ch[which]].f=old; t[x].ch[which^1]=old,t[old].f=x,t[x].f=fold; pushup(old),pushup(x); } void splay(int x) { int u=x,fa; for(;!isrt(u);u=t[u].f); for(u=t[u].f;(fa=t[x].f)!=u;rotate(x)) { if(t[fa].f!=u) { rotate(get(fa)==get(x)?fa:x); } } } void Access(int x) { for(int y=0;x;y=x,x=t[x].f) { splay(x); if(rson) { t[x].son+=t[rson].siz; t[x].sqr+=(LL)t[rson].siz*t[rson].siz; } if(y) { t[x].son-=t[y].siz; t[x].sqr-=(LL)t[y].siz*t[y].siz; } rson=y; pushup(x); } } void link(int x,int y) { Access(x),splay(x); t[y].f=x; t[x].son+=t[y].siz; t[x].sqr+=(LL)t[y].siz*t[y].siz; pushup(x); } // x 与 x 的父亲 void cut(int x) { Access(x),splay(x); if(lson) { t[lson].f=0; lson=0; pushup(x); } } int findroot(int x) { Access(x),splay(x); while(lson) x=lson; return x; } void turn_0(int x) { Access(x),splay(x); int now=t[x].siz; ans-=t[x].sqr; if(fa[x]) link(fa[x],x); int p=findroot(x); splay(p); is[x]=0; p=is[p]?t[p].ch[1]:p; int ori=t[p].siz; ans-=(LL)(ori-now)*(ori-now); ans+=(LL)ori*ori; } void turn_1(int x) { int p=findroot(x); splay(p); p=is[p]?t[p].ch[1]:p; int ori=t[p].siz; ans-=(LL)ori*ori; cut(x); int now=t[x].siz; ans+=(LL)(ori-now)*(ori-now); ans+=(LL)t[x].sqr; is[x]=1; } void dfs(int u,int ff) { size[u]=1; fa[u]=t[u].f=ff; for(int i=hd[u];i;i=nex[i]) { int v=to[i]; if(v==ff) continue; dfs(v,u); size[u]+=size[v]; t[u].son+=size[v]; t[u].sqr+=(LL)size[v]*size[v]; } pushup(u); } int main() { // setIO("input"); int i,j,n,m,mx=0; scanf("%d%d",&n,&m); for(i=1;i<=n;++i) { scanf("%d",&col[i]); val[i]=col[i]; mx=max(mx,col[i]); G[val[i]].push_back(data(i,val[i],0)); } for(i=1;i<n;++i) { int u,v; scanf("%d%d",&u,&v); add(u,v),add(v,u); } dfs(1,0); for(i=1;i<=m;++i) { int u,v; scanf("%d%d",&u,&v); mx=max(mx,v); if(val[u]==v) continue; G[val[u]].push_back(data(u,v,i)); // val[u]->v G[v].push_back(data(u,v,i)); // ?->v val[u]=v; } for(i=1;i<=mx;++i) { ans=(LL)n*n; LL pre=0; for(j=0;j<G[i].size();++j) { if(G[i][j].v==i) // 别的变成 i (0->1) { turn_1(G[i][j].u); } else // i 变成别的 (1->0) { turn_0(G[i][j].u); } re[G[i][j].tim]-=pre; re[G[i][j].tim]+=(LL)n*n-ans; pre=(LL)n*n-ans; } for(j=G[i].size()-1;j>=0;--j) { if(G[i][j].v==i) // 别的变成 i (0->1) { turn_0(G[i][j].u); } else // i 变成别的 (1->0) { turn_1(G[i][j].u); } } } printf("%lld ",re[0]); for(i=1;i<=m;++i) re[i]+=re[i-1], printf("%lld ",re[i]); return 0; }