zoukankan      html  css  js  c++  java
  • CF739E Gosha is hunting(费用流/凸优化dp)

    纪念合格考爆炸。

    其实这个题之前就写过博客了,qwq但是不小心弄丢了,所以今天来补一下。

    首先,一看到球的个数的限制,不难相当用网络流的流量来限制每个球使用的数量。

    由于涉及到最大化期望,所以要使用最大费用最大流。

    我们新建两个点(ss,tt),分别表示两种球。

    那么我们现在考虑应该怎么计算期望呢。

    首先,如果假设如果对于一个怪物用一个球,那么连边也就比较容易了
    对于一个怪物(x)
    我们(ss -> x),费用为(p[i]),流量为1。表示一个球在一个怪物上只能用一次。
    (tt)也是同理。

    然后对于每一个(x->t),费用是(0),流量是(1),表示一个怪物只能用一个球。

    但是,要是每次不要求只能用一个球应该怎么做呢。

    我们考虑,这条边的费用应该是多少。

    两个球都用的期望应该是(1-(1-p_i)(1-q_i))
    经过展开,我们发现应该是(p_i+q_i-p_i imes q_i)

    那么由于我们发现,由于用了两个球,所以已经获得了二者之和的收益,那么在这一侧,只需要在上述建图的基础上(x->t),费用是(-p_i imes q_i)即可。

    最后跑一发最大费用最大流就能通过这个题qwq时间复杂度玄学。

    #include<bits/stdc++.h>
    #define pb push_back
    #define mk make_pair
    #define ll long long
    #define db double
    
    using namespace std;
    
    inline int read()
    {
       int x=0,f=1;char ch=getchar();
       while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
       while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
       return x*f;
    }
    
    const int maxn = 4010;
    const int maxm = 3e6+1e2;
    const double eps = 1e-10;
    
    int point[maxn],nxt[maxm],to[maxm],pre[maxm],from[maxn];
    double dis[maxn];
    int vis[maxn];
    double cost[maxm];
    int flow[maxm];
    double ans;
    int n,m,cnt=1;
    int s,t;
    
    void addedge(int x,int y,db w,int f)
    {
        nxt[++cnt]=point[x];
        pre[cnt]=x;
        to[cnt]=y;
        cost[cnt]=w;
        flow[cnt]=f;
        point[x]=cnt;
    }
    
    void insert(int x,int y,db w,int f)
    {
        addedge(x,y,w,f);
        addedge(y,x,-w,0);
    }
    
    queue<int> q;
    
    bool spfa(int s)
    {
        for (int i=1;i<=maxn-3;i++) dis[i]=-1e9;
        memset(vis,0,sizeof(vis));
        q.push(s);
        dis[s]=0;
        while (!q.empty())
        {
            int x = q.front();
            q.pop();
            vis[x]=0;
            for (int i=point[x];i;i=nxt[i])
            {
                int p = to[i];
                if (dis[p]-(dis[x]+cost[i])<-eps && flow[i]>0)
                {
                    from[p]=i;
                    dis[p]=dis[x]+cost[i];
                    if (!vis[p])
                    {
                        q.push(p);
                        vis[p]=1;
                    }
                }
            }
        }
        if (dis[t]==-1e9) return false;
        return true;
    }
    
    void mcf()
    {
        double x = 1e9;
        for (int i=from[t];i;i=from[pre[i]]) x=min(x,1.0*flow[i]);
        for (int i=from[t];i;i=from[pre[i]])
        {
            flow[i]-=x;
            flow[i^1]+=x;
            ans+=x*cost[i];
        }
    }
    
    void solve()
    {
        while (spfa(s)) mcf();
    }
    
    db a[maxn],b[maxn];
    int ss,sss;
    int aa,bb;
    
    int main()
    {
       n=read(),aa=read(),bb=read();
       s=maxn-10;
       ss=s+1;
       t=s+3;
       sss=ss+1;
       insert(s,ss,0,aa);
       insert(s,sss,0,bb);
       for (int i=1;i<=n;i++) scanf("%lf",&a[i]);
       for (int i=1;i<=n;i++) scanf("%lf",&b[i]);
       for (int i=1;i<=n;i++)
       {
       	  insert(ss,i,a[i],1);
       	  insert(sss,i,b[i],1);
       	  insert(i,t,0,1);
       	  insert(i,t,-a[i]*b[i],1);
       }
       solve();
       printf("%.4lf
    ",ans);
       return 0;
    }
    
    

    但是其实这个题的正解是凸优化(dp)

    首先,先做一个最(naive)的想法。

    我们令(dp[i][j][k])表示前(i)个怪物,用了(j)一号球,用了(k)个二号球

    那么转移也是显然的。
    每次只需要讨论一下对于当前的怪物是用几个球,用哪个即可。

    但是这样的复杂度是(O(n^3))的。
    显然没有办法通过。

    考虑怎么优化。

    由于题目中涉及到的正好用几个球,并且通过打表发现函数是凸的,那么我们就可以直接用凸优化来优化掉一维。

    (其实是可以直接优化两个的,但是我太懒,所以没写。)

    我们对于当前二分的值,表示每选一个二号球,就可以多得到(mid)的期望。不限制选的个数。

    那么不难得到下面的这个转移式子。

    dp[i][j]=dp[i-1][j];
    dp[i][j]=max(dp[i][j],dp[i-1][j]+(ymh){bb[i],1});
    if (j)
    {
       	 dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){a[i],0});
       	 dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){both[i],1});
    }
    

    然后通过调整(mid),通过正好选到(k)个二号球。

    最后求一个(dp)数组,然后记得把贡献减去就行。

    时间复杂度(n^2log),非常优秀。

    (其实是如果精度太小会(WA),精度太大会(TLE)

    但是完全可以做到(nlog^2)的。

    给代码。

    #include<bits/stdc++.h>
    #define pb push_back
    #define mk make_pair
    #define ll long long
    #define db double 
    
    using namespace std;
    
    inline int read()
    {
       int x=0,f=1;char ch=getchar();
       while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
       while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
       return x*f;
    }
    
    const int maxn = 2010;
    const db eps = 1e-6;
    
    struct ymh{
        db val;
        int num;
        ymh operator + (const ymh &b) const
        {
            return (ymh){val+b.val,num+b.num};
        }
    };
    
    ymh dp[maxn][maxn];
    db a[maxn],b[maxn];
    int n;
    db l=-4,r=4;
    
    inline int dcmp(double x,double y) 
    {
      return x-y<-eps ? -1 : (x-y>eps ? 1 : 0);
    }
    
    inline ymh max(ymh a,ymh b)
    {
        int now = dcmp(a.val,b.val);
        if (now==0)
        {
            if (a.num<b.num) return a;
            else return b;
        }
        else
        {
            if(now==-1) return b;
            else return a;
        }
    }
    
    int numa,numb;
    db aa[maxn];
    db bb[maxn];
    db both[maxn];
    
    bool check(db lim)
    {
       for (int i=1;i<=n;i++) aa[i]=a[i];
       for (int i=1;i<=n;i++) bb[i]=b[i]+lim;
       for (int i=1;i<=n;i++) both[i]=1.0-(1.0-a[i])*(1.0-b[i])+lim;
       for (register int i=1;i<=n;++i)
       {
       	  for (register int j=0;j<=numa;++j)
       	  {
       	  	dp[i][j]=dp[i-1][j];
       	  	dp[i][j]=max(dp[i][j],dp[i-1][j]+(ymh){bb[i],1});
       	  	if (j)
       	  	{
       	  		dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){a[i],0});
       	  		dp[i][j]=max(dp[i][j],dp[i-1][j-1]+(ymh){both[i],1});
            }
        //	if (dp[i][j].num>numb) return false;
          }
       }
       return dp[n][numa].num<=numb;
    }
    
    int main()
    {
       n=read(),numa=read(),numb=read();
       for (int i=1;i<=n;i++) scanf("%lf",&a[i]);
       for (int i=1;i<=n;i++) scanf("%lf",&b[i]);
       double ans=0;
       while (r-l>=eps)
       {
       	  db mid = (l+r)/2;
       	 // memset(dp,0,sizeof(dp));
       	  
       	  if (check(mid)) l=mid,ans=mid;
       	  else r=mid;
       	  //printf("%.4lf %d
    ",mid,dp[n][numa].num);
       }
       //cout<<1<<endl; 
       //printf("%.4lf
    ",ans);
       //memset(dp,0,sizeof(dp));
       check(ans);
       //printf("")
       //printf("%.4lf %d
    ",dp[n][numa].val,dp[n][numa].num);
       printf("%.4lf",dp[n][numa].val-1.0*numb*ans); 
       return 0;
    }
    
    
  • 相关阅读:
    微软软件下载
    FTP主动连接与被动连接
    Linux下grep显示前后几行信息
    cacti 安装过程中“ERROR: 您的MySQL TimeZone 数据库未被填充. 请在继续之前填入此数据库.”
    Cacti安装详细步骤
    Linux 踢掉其他终端用户
    迁移设备存储报的错误及解决方式
    sql_mode :(STRICT_TRANS_TABLES与STRICT_ALL_TABLES 区别)
    Nginx日志按日期切割详解(按天切割)
    git pull冲突:commit your changes or stash them before you can merge.
  • 原文地址:https://www.cnblogs.com/yimmortal/p/10268996.html
Copyright © 2011-2022 走看看