正题
题目链接:https://www.luogu.com.cn/problem/AT3968
题目大意
给出\(n\)个点的一棵树,然后给出\(m\)条路径。每条边的权值是它是否又被正向经过+是否又被反向经过,给每条路径定向使得所有边的权值和最大。输出方案。
\(1\leq n,m\leq 2000\)
解题思路
显然第一条路径我们可以随意定向,然后与它同向的且经过重复边的路径我们取反其中一条然后再以这条路径继续去取反别的相交路径。
这样下去会发现显然两条路径相交的部分都会被统计两次,也就是如果一条边被经过了两次那么这里的贡献就是\(2\),否则被经过一次就是\(1\),也就是答案能到达上界。
但是这样构造显然容易出问题,考虑一个更加系统的构造方法。
先把所有路径挂到两端点上,然后我们每次找到一个叶子节点,考虑这个叶子节点
- 没有路径:那么直接删除这个叶子。
- 有一条路径:那么这条路径的方向和这个叶子无关,直接把路径丢到父节点然后删除这个叶子。
- 有两条或者以上路径:此时我们指定任意两条边反向,但是这两条边谁正谁反还未确定,假设是\((x,a),(x,b)\),那我们新建一条路径\((a,b)\)(也就是这两条路径在分叉点后产生的新路径)然后这两条再根据这条路径确定方向。
用\(set\)维护端点的路径,删掉的路径记得在对应的节点也删掉,细节有点多慢慢调。
时间复杂度:\(O(n^2\log n)\)
code
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<set>
#include<queue>
#define mp(x,y) make_pair(x,y)
using namespace std;
const int N=2100,M=5100;
struct node{
int to,next;
}a[N<<1];
int n,m,tot,c[N],ls[N],in[N],fa[N];
int ps[N],pt[N],s[M],t[M];
queue<int> q;set<int> p[M];
pair<int,int> ans[M];
void addl(int x,int y){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;return;
}
void dfs(int x){
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa[x])continue;
fa[y]=x;dfs(y);in[x]++;
}
return;
}
bool work(int x,int z,int fr){
if(x==z)return 1;
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fr)continue;
int w=work(y,z,x);
if(w){
if(y==fa[x])c[x]++;
else c[y]++;
return 1;
}
}
return 0;
}
void given(int x){
int i=*p[x].begin();
p[x].erase(i);
p[fa[x]].insert(i);
if(i>0)s[i]=fa[x];
else t[-i]=fa[x];
i=abs(i);
if(s[i]==t[i]){
p[s[i]].erase(i);
p[s[i]].erase(-i);
}
return;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
addl(x,y);addl(y,x);
}
tot=m;dfs(1);int sum=0;
for(int i=1;i<=m;i++){
scanf("%d%d",&s[i],&t[i]);
ps[i]=s[i];pt[i]=t[i];
p[s[i]].insert(i);
p[t[i]].insert(-i);
work(s[i],t[i],0);
}
for(int i=2;i<=n;i++){
if(!in[i])q.push(i);
sum+=min(c[i],2);
}
while(!q.empty()){
int x=q.front();q.pop();
if(x==1)break;
in[fa[x]]--;
if(!in[fa[x]])q.push(fa[x]);
if(p[x].size()==1)given(x);
else if(p[x].size()>1){
++tot;
int i=*p[x].begin();p[x].erase(i);
int j=*p[x].begin();p[x].erase(j);
s[tot]=(i>0)?t[i]:s[-i];t[tot]=(j>0)?t[j]:s[-j];
if(s[tot]!=t[tot])p[s[tot]].insert(tot),p[t[tot]].insert(-tot);
p[s[tot]].erase(-i);p[t[tot]].erase(-j);
ans[abs(i)]=mp(tot,-i/abs(i));
ans[abs(j)]=mp(tot, j/abs(j));
while(!p[x].empty())given(x);
}
}
printf("%d\n",sum);
for(int i=tot;i>=1;i--){
if(ans[i].second!=0)
ans[i].first=ans[ans[i].first].first*ans[i].second;
else ans[i].first=1;
}
for(int i=1;i<=m;i++){
if(ans[i].first>0)printf("%d %d\n",ps[i],pt[i]);
else printf("%d %d\n",pt[i],ps[i]);
}
return 0;
}