题目链接:https://codeforces.com/problemset/problem/1444/C
第一想法就是暴力枚举所有两个点的子图然后判断二分图
正难则反,考虑补集转化,即不合法的方案数,有两种情况:
- 存在单色环
- 存在双色环
所以只需要统计出不含环的单一颜色的数量 (cnt) 和不合法的双色环数量 (k)
最终答案就是 $ frac{cnt * (cnt - 1)}{2} - k$
统计答案使用并查集判断是否合法即可,
先将合法的单色边连好,然后依次加另一种颜色的边(边提前排好序保证相同颜色的被连续枚举),
每统计完一种颜色都要撤销掉这些操作(也即回到历史版本)
可持久化并查集即可
同时维护奇偶性需要用到扩展域并查集,将一个点拆成两个,(x_self) 和 (x_another),
如果有边则将 (x_self, y_another) 和 (x_another, y_self) 连到一起,表示两点不在同一队伍里,
如果 (x_self, y_self) 已经在同一集合里,则不合法
坑点:如果使用带权并查集的话,一定要路径压缩,而路径压缩在可持久化并查集中并不容易实现,所以选择扩展域并查集
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<stack>
#include<queue>
using namespace std;
typedef long long ll;
const int maxn = 500010;
int n, m, k, cnt, tot; ll ans;
int c[maxn], rt[maxn * 5];
struct E{
int u, v, uc, vc;
bool operator < (const E &a) const {
if(uc == a.uc){
return vc < a.vc;
}
return uc < a.uc;
}
}e[maxn];
int fa[maxn * 2], ran[maxn * 2];
int vis[maxn];
struct Node{
int lc, rc;
int fa, ran;
}t[maxn * 50];
int fi(int x){ return fa[x] == x ? x : fi(fa[x]); }
void unite(int x, int y){
x = fi(x), y = fi(y);
if(x == y) return;
if(ran[x] < ran[y]){
fa[x] = y;
} else{
fa[y] = x;
if(ran[x] == ran[y]) ++ran[x];
}
}
void solve_1(){
for(int i = 1 ; i <= n + n ; ++i) fa[i] = i, ran[i] = 0;
for(int i = 1 ; i <= m ; ++i){
if(e[i].uc == e[i].vc && !vis[e[i].vc]){
int x_self = fi(e[i].u), x_ano = fi(e[i].u + n);
int y_self = fi(e[i].v), y_ano = fi(e[i].v + n);
if(x_self == y_self){
--cnt;
vis[e[i].uc] = 1;
} else{
unite(e[i].u, e[i].v + n);
unite(e[i].v, e[i].u + n);
}
}
}
}
void build(int &i, int l, int r){
i = ++tot;
if(l == r){
t[i].fa = l;
t[i].ran = 0;
return;
}
int mid = (l + r) >> 1;
build(t[i].lc, l, mid);
build(t[i].rc, mid + 1, r);
}
void modify(int &i, int k, int p, int l, int r){
t[++tot] = t[i];
i = tot;
if(l == r){
t[i].fa = k;
return;
}
int mid = (l + r) >> 1;
if(p <= mid) modify(t[i].lc, k, p, l, mid);
else modify(t[i].rc, k, p, mid + 1, r);
}
int query(int i, int p, int l, int r){
if(l == r) return i; // 返回节点编号
int mid = (l + r) >> 1;
if(p <= mid) return query(t[i].lc, p, l, mid);
else return query(t[i].rc, p, mid + 1, r);
}
//并查集
void add(int i, int p, int l, int r){
if(l == r){
++t[i].ran;
return;
}
int mid = (l + r) >> 1;
if(p <= mid) add(t[i].lc, p, l, mid);
else add(t[i].rc, p, mid + 1, r);
}
int find(int v, int x){
int ff = query(v, x, 1, n + n);
if(t[ff].fa == x) return ff;
return find(v, t[ff].fa);
}
void uni(int &v, int x, int y){
x = find(v, x), y = find(v, y);
if(t[x].ran < t[y].ran){
modify(v, t[y].fa, t[x].fa, 1, n + n);
} else{
modify(v, t[x].fa, t[y].fa, 1, n + n);
if(t[x].ran == t[y].ran){
add(v, t[x].fa, 1, n + n);
}
}
}
void solve_2(){
tot = 0;
build(rt[0], 1, n + n);
int ver = 0, his;
for(int i = 1 ; i <= m ; ++i){ // 先把同色的合法并查集连接起来
if(e[i].uc == e[i].vc && !vis[e[i].uc]){
++ver;
rt[ver] = rt[ver - 1];
uni(rt[ver], e[i].u, e[i].v + n);
uni(rt[ver], e[i].v, e[i].u + n);
}
}
// for(int i = 1 ; i <= n ; ++i){
// int x = find(rt[ver], i);
// printf("%d ", t[x].fa);
// } printf("
");
his = ver;
int flag;
for(int i = 1 ; i <= m ; ++i){
if(e[i].uc == e[i].vc || vis[e[i].uc] || vis[e[i].vc]) continue;
++ver;
if(!(e[i].uc == e[i - 1].uc && e[i].vc == e[i - 1].vc)){
rt[ver] = rt[his];
flag = 0;
} else {
if(flag) continue;
rt[ver] = rt[ver - 1];
}
int p_self = find(rt[ver], e[i].u), p_ano = find(rt[ver], e[i].u + n);
int q_self = find(rt[ver], e[i].v), q_ano = find(rt[ver], e[i].v + n);
if(t[p_self].fa == t[q_self].fa){
--ans;
flag = 1;
} else {
uni(rt[ver], e[i].u, e[i].v + n);
uni(rt[ver], e[i].v, e[i].u + n);
}
}
}
ll read(){ ll s=0,f=1; char ch=getchar(); while(ch<'0' || ch>'9'){ if(ch=='-') f=-1; ch=getchar(); } while(ch>='0' && ch<='9'){ s=s*10+ch-'0'; ch=getchar(); } return s*f; }
int main(){
n = read(), m = read(), k = read(); cnt = k;
for(int i = 1 ; i <= n ; ++i) c[i] = read();
for(int i = 1 ; i <= m ; ++i){
e[i].u = read(), e[i].v = read();
e[i].uc = c[e[i].u], e[i].vc = c[e[i].v];
if(e[i].uc > e[i].vc){
swap(e[i].u, e[i].v);
swap(e[i].uc, e[i].vc);
}
}
sort(e + 1, e + 1 + m);
// for(int i = 1 ; i <= m ; ++i){
// printf("%d %d %d %d
", e[i].u, e[i].v, e[i].uc, e[i].vc);
// }
solve_1();
ans = 1ll * (cnt - 1) * cnt / 2;
// printf("%d
", ans);
solve_2();
printf("%lld
", ans);
return 0;
}