我丢 之前sun在某校集训给我看过 当时也没想起来 今天补省集的锅的时候发现 wok这题我还听过?!
身败名裂.jpg (可是你记性不好这事情不已经人尽皆知了吗?
咳咳 回归正题
考虑对于两个同色的点:
1)不构成祖先关系
那么两个子树里的点都不能选 相当于矩形覆盖
2)构成祖先关系
祖先刨掉一个子树,儿子子树不能选
拆成两个矩形
最后考虑统计答案,发现对称做然后(总点数-答案)/2就是答案
(因为对角线上的点总是合法的 所以要加上qwq)
然后就是矩形的并数点了 直接扫描线+线段树就好了
//Love and Freedom. #include<cstdio> #include<algorithm> #include<cstring> #include<cmath> #include<vector> #include<cassert> #define ll long long #define inf 20021225 #define N 100010 #define pb push_back using namespace std; int read() { int s=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();} while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar(); return f*s; } //--------- tree ----------- struct edge{int to,lt;}e[N<<1]; int in[N],cnt,dep[N],sz[N],f[N][18],dfn[N],tms,idfn[N]; void add(int x,int y) { e[++cnt].to=y; e[cnt].lt=in[x]; in[x]=cnt; e[++cnt].to=x; e[cnt].lt=in[y]; in[y]=cnt; } void dfs(int x) { dfn[x]=++tms; idfn[tms]=x; sz[x]=1; for(int i=1;i<18;i++) f[x][i]=f[f[x][i-1]][i-1]; for(int i=in[x];i;i=e[i].lt) { int y=e[i].to; if(y==f[x][0]) continue; dep[y]=dep[x]+1; f[y][0]=x; dfs(y); sz[x]+=sz[y]; } } int LCA(int x,int y) { if(dep[x]<dep[y]) swap(x,y); int len=dep[x]-dep[y]; for(int i=0;i<18;i++) if(len>>i&1) x=f[x][i]; if(x==y) return x; for(int i=17;~i;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i]; return f[x][0]; } int col[N]; struct node { int xd,xu,y,v; }p[N*210]; bool operator<(node a,node b){return a.y<b.y;} //----------- SGT ------------ // 0->not fully covered >0 -> fully covered #define ls x<<1 #define rs x<<1|1 int s[N<<2],tag[N<<2]; void insert(int x,int l,int r,int LL,int RR,int v) { if(RR<LL) return; if(LL<=l&&RR>=r) { tag[x]+=v; if(tag[x]>0) s[x]=r-l+1; else if(l==r) s[x]=0; else s[x]=s[ls]+s[rs]; return; } int mid=l+r>>1; if(LL<=mid) insert(ls,l,mid,LL,RR,v); if(mid<RR) insert(rs,mid+1,r,LL,RR,v); if(!tag[x]) s[x]=s[ls]+s[rs]; } vector<int> cc[N]; int tot; int main() { //freopen("tree3-4.in","r",stdin); int n=read(); for(int i=1;i<=n;i++) col[i]=read(),cc[col[i]].pb(i); for(int i=1;i<n;i++) add(read(),read()); dfs(1); for(int i=1;i<=n;i++) for(int j=0;j<cc[i].size();j++) for(int k=j+1;k<cc[i].size();k++) { int x=cc[i][j],y=cc[i][k],z=LCA(x,y); if(z==x||z==y) { if(z==y) swap(x,y); int len=dep[y]-dep[x]-1; x=y; for(int i=0;i<18;i++) if(len>>i&1) x=f[x][i]; p[++tot]=(node){1,dfn[x]-1,dfn[y],1}; p[++tot]=(node){1,dfn[x]-1,dfn[y]+sz[y],-1}; p[++tot]=(node){dfn[x]+sz[x],n,dfn[y],1}; p[++tot]=(node){dfn[x]+sz[x],n,dfn[y]+sz[y],-1}; p[++tot]=(node){dfn[y],dfn[y]+sz[y]-1,1,1}; p[++tot]=(node){dfn[y],dfn[y]+sz[y]-1,dfn[x],-1}; p[++tot]=(node){dfn[y],dfn[y]+sz[y]-1,dfn[x]+sz[x],1}; p[++tot]=(node){dfn[y],dfn[y]+sz[y]-1,n+1,-1}; } else { p[++tot]=(node){dfn[x],dfn[x]+sz[x]-1,dfn[y],1}; p[++tot]=(node){dfn[x],dfn[x]+sz[x]-1,dfn[y]+sz[y],-1}; p[++tot]=(node){dfn[y],dfn[y]+sz[y]-1,dfn[x],1}; p[++tot]=(node){dfn[y],dfn[y]+sz[y]-1,dfn[x]+sz[x],-1}; } } sort(p+1,p+tot+1); ll ans=0; for(int i=1;i<=tot;i++) { insert(1,1,n,p[i].xd,p[i].xu,p[i].v); if(i!=tot && p[i].y!=p[i+1].y) ans+=1ll*(p[i+1].y-p[i].y)*s[1]; } printf("%lld ",(1ll*n*(n+1)-ans)>>1); return 0; }