题目链接:最长不下降子序列问题
这一个问题虽然有三小问,但是每一个小问题的连接非常紧密
对于第一问,直接(O(n^2))水过,你要用(O(nlogn))当然也可以啊
二三问考虑使用网络流求解
我们利用第一问中得到的dp关系来建图:
很明显的是这里的每一个数只能用一次,所以我们将每一个点拆成两个点,之间连一条容量为1的边,其中一个点作为入点,另一个点作为出点(这一点十分重要,是网络流中一个常规且重要的操作)
dp的起始点也就是(dp[i]=1)时,我们从(s)向(i)连一条容量为1的边
dp的终点也就是(dp[i]=s)时,我们从(i)向(t)连一条容量为1的边
在中间出现转移的时候(也就是满足(dp[j]+1=dp[i] && a[j]<=a[i])时),我们从(j)向(i)连一条容量为1的边
这样直接去跑网络流的话第二问就解决了
对于第三问,最直接的思路就是对这个图进行改造或重建,由于此时(x_1)和(x_n)可以使用多次,我们可以将(1)中两个点所连的边的容量以及(s)到1的边的容量改成INF。
同时如果(n)与(t)有连边的话就同样进行上述操作
一个小技巧是:我们并不需要重新建图,由于我们并未该小流量,我们可以直接加上上面的几条边到残量网络中,然后继续跑网络流,将两次的答案相加就是第三问的答案了
#include<iostream>
#include<string>
#include<string.h>
#include<stdio.h>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<map>
using namespace std;
#define maxd 1e9+7
struct network_flows{
struct node{
int from,to,nxt,flow;
}sq[100100];
int all,dep[100100],head[100100],cur[100100],n,m,s,t;
bool vis[100100];
void init(int n)
{
this->s=2*n+1;this->t=2*n+2;
this->n=2*n+2;this->all=1;
memset(head,0,sizeof(head));
}
void add(int u,int v,int w)
{
all++;sq[all].from=u;sq[all].to=v;sq[all].nxt=head[u];sq[all].flow=w;head[u]=all;
all++;sq[all].from=v;sq[all].to=u;sq[all].nxt=head[v];sq[all].flow=0;head[v]=all;
}
bool bfs()
{
queue<int> q;int i;
memset(vis,0,sizeof(vis));
vis[s]=1;q.push(s);dep[s]=0;
while (!q.empty())
{
int u=q.front();q.pop();
for (i=head[u];i;i=sq[i].nxt)
{
int v=sq[i].to;
if ((!vis[v]) && (sq[i].flow))
{
vis[v]=1;dep[v]=dep[u]+1;q.push(v);
}
}
}
if (!vis[t]) return 0;
for (i=1;i<=n;i++) cur[i]=head[i];
return 1;
}
int dfs(int now,int to,int lim)
{
if ((!lim) || (now==to)) return lim;
int i,sum=0;
for (i=head[now];i;i=sq[i].nxt)
{
int v=sq[i].to;
if (dep[now]+1==dep[v])
{
int f=dfs(v,to,min(lim,sq[i].flow));
if (f)
{
lim-=f;sum+=f;
sq[i].flow-=f;
sq[i^1].flow+=f;
if (!lim) break;
}
}
}
return sum;
}
int work()
{
int ans=0;
while (bfs()) ans+=dfs(s,t,maxd);
return ans;
}
}dinic;
int n,a[1010],dp[1010],s,t,ans1;
int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
void init()
{
n=read();int i,j;
dinic.init(n);
for (i=1;i<=n;i++) a[i]=read();
for (i=1;i<=n;i++)
{
int best=0;
for (j=1;j<i;j++)
if ((a[j]<=a[i]) && (dp[j]>dp[best])) best=j;
//cout << i << " " << best << endl;
dp[i]=dp[best]+1;
}
for (i=1;i<=n;i++) ans1=max(dp[i],ans1);
printf("%d
",ans1);
}
void make_sq()
{
s=2*n+1;t=2*n+2;int i,j;
//for (i=1;i<=n;i++) printf("%d ",dp[i]);cout << endl;
for (i=1;i<=n;i++)
{
dinic.add(i,i+n,1);
if (dp[i]==1) dinic.add(s,i,1);
if (dp[i]==ans1) dinic.add(i+n,t,1);
for (j=1;j<i;j++)
if ((a[i]>=a[j]) && (dp[i]==dp[j]+1)) dinic.add(j+n,i,1);
}
}
void work()
{
int ans2=dinic.work();
dinic.add(s,1,maxd);dinic.add(1,n+1,maxd);
if (dp[n]==ans1) {dinic.add(n+n,t,maxd);dinic.add(n,n+n,maxd);}
int ans3=dinic.work();
printf("%d
%d
",ans2,ans2+ans3);
}
int main()
{
init();
make_sq();
work();
return 0;
}