zoukankan      html  css  js  c++  java
  • SDOI征途--斜率优化

    题目描述

    给定长为 n 的数列 a, 要求划分成 m 段,使得方差最小, 输出方差(*m^2)

    题解

    斜率优化好题

    准备部分

    设第 i 段长为 (len_i)
    先考虑方差((S^2))的式子:

    [S^2 = frac{1}{m}*sum_{i=1}^m(len_i - (frac{1}{m}*sum_{j=1}^{m}len_j) )^2 ]

    拆项得 -->

    [S^2 = frac{1}{m}sum_{i=1}^{m}len_i^2+frac{1}{m}sum_{i=1}^{m}frac{1}{m^2}sum_{j=1}^{m}len_j-2*frac{1}{m}*sum_{i=1}^{m}(len[i]*frac{1}{m}sum_{j=1}^{m}len_j) ]

    -->

    [S^2 = frac{1}{m}sum_{i=1}^{m}len_i^2+frac{1}{m^2}sum_{i=1}^{m}len_i^2-2*frac{1}{m^2}*(sum_{i=1}^{m}len_i)*(sum_{j=1}^{m}len_j) ]

    -->

    [S^2 = frac{1}{m}sum_{i=1}^{m}len_i^2-frac{1}{m^2}sum_{i=1}^{m}len_i^2 ]

    再把(m^2)乘进去

    [m*sum_{i=1}^{m}len_i^2-sum_{i=1}^{m}len_i^2 ]

    可发现$$-sum_{i=1}{m}len_i2$$ 这一坨是常数,也就是原序列和的平方

    然后开始DP

    设 f[i][k] 表示当前在第 i 个数,划分成 k 段
    转移时枚举第 k 段的起点 j+1 (终点是 i ,注意这里枚举的是j+1):

    [f[i][k] =f[j][k-1]+(sum_{l=j+1}^{i}a[l])^2 ]

    再用滚动数组g(也可不用)与前缀和sum记录一下 a[l] 优化就好
    也就是

    [f[i]=g[j]+(sum[i]-sum[j])^2 ]

    斜率优化

    把上面的DP转移方程拆开即得:

    [g[j]+sum[j]^2=-2*sum[i]*sum[j]+(sum[i]^2-f[i]) ]

    (g[j]+sum[j]^2) 看做 Y;
    (sum[j]^2) 看做 X;
    (-2*sum[i]) 看做 K;
    ((sum[i]^2-f[i])) 看做 B;
    然后用单调队列维护一下斜率递增的决策点就好

    代码

    #include<bits/stdc++.h>
    using namespace std;
    #define re register
    #define in inline
    #define get getchar()
    #define ll long long
    in int read()
    {
    	int t=0; char ch=get;
    	while(ch<'0' || ch>'9') ch=get;
    	while(ch<='9' && ch>='0') t=t*10+ch-'0', ch=get;
    	return t;
    }
    const int _=5001;
    int n,m;
    ll sum[_],f[_],g[_],q[_];
    #define db double
    in db  calc(int a,int b)
    {
    	ll Y1=g[a]+sum[a]*sum[a],Y2=g[b]+sum[b]*sum[b];
    	return 1.0*(Y1-Y2)/(sum[a]-sum[b]);
    }
    int main()
    {
    	n=read(),m=read();
    	for(re int i=1;i<=n;i++){
    		sum[i]=sum[i-1]+read();
    		g[i]=sum[i]*sum[i];
    	}
    	for(re int k=2;k<=m;k++)
    	{
    		int l=1,r=0;
    		q[l]=0;q[0]=0;
    		for(re int i=1;i<=n;i++)
    		{
    			while(l<r && calc(q[l],q[l+1]) < 2*sum[i]) l++;
    			int j=q[l];
    			f[i]=g[j]+(sum[i]-sum[j])*(sum[i]-sum[j]);
    			while(l<r && calc(q[r],i) < calc(q[r-1],i)) r--;
    			q[++r]=i;
    		}
    		memcpy(g,f,sizeof(g));
    	}
    	cout<<m*f[n]-sum[n]*sum[n]<<endl;
    }
    
    
  • 相关阅读:
    区别@ControllerAdvice 和@RestControllerAdvice
    Cannot determine embedded database driver class for database type NONE
    使用HttpClient 发送 GET、POST、PUT、Delete请求及文件上传
    Markdown语法笔记
    Property 'sqlSessionFactory' or 'sqlSessionTemplate' are required
    Mysql 查看连接数,状态 最大并发数(赞)
    OncePerRequestFilter的作用
    java连接MySql数据库 zeroDateTimeBehavior
    Intellij IDEA 安装lombok及使用详解
    ps -ef |grep xxx 输出的具体含义
  • 原文地址:https://www.cnblogs.com/yzhx/p/11779148.html
Copyright © 2011-2022 走看看