zoukankan      html  css  js  c++  java
  • 算法笔记--斜率优化dp

    斜率优化是单调队列优化的推广

    用单调队列维护递增的斜率

    参考:https://www.cnblogs.com/ka200812/archive/2012/08/03/2621345.html

    以例1举例说明:

    转移方程为:dp[i] = min(dp[j] + (sum[i] - sum[j])^2 + C)

    假设k < j < i, 如果从j转移过来比从k转移过来更优

    那么 dp[j] + (sum[i] - sum[j])^2 + C < dp[k] + (sum[i] - sum[k])^2 + C

    dp[j] - dp[k] < (sum[i] - sum[k])^2 - (sum[i] - sum[j])^2

    dp[j] - dp[k] < -2*sum[i]*sum[k] + sum[k]*sum[k] + 2*sum[i]*sum[j] - sum[j]*sum[j]

    dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k] < 2*sum[i]*(sum[j] - sum[k])

    (dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k]) / (sum[j] - sum[k]) < 2*sum[i]

    我们观察不等式左边, 它是个斜率的形式, 自变量x为sum, 函数f(x)为dp + sum*sum

    我们记这个斜率为g[j, k] = (dp[j] - dp[k] + sum[j]*sum[j] - sum[k]*sum[k]) / (sum[j] - sum[k])

    说明1.如果g[j, k] < 2*sum[i] 表示对于dp[i], 从j转移过来比k更优, 反之k更优

    说明2.下面我们来考虑怎样从解集中去掉多余的元素: 可以证明, 解集中可能存在某些元素, 无论对于哪个dp[i]都不会是最优的, 这些多余的元素可以直接去掉

    假设k < j < i

    结论:如果g[i, j] < g[j, k], 那么j可以去掉

    证明:对于之后任意一个要计算的dp[t] (t > i), 如果g[i, j] < 2*sum[t], 那么由说明1, i比j更优, j不会是最优解, 结论成立;

                             如果g[i, j] >= 2*sum[t], 那么g[j, k] > g[i, j] >= 2*sum[t], 由说明1, k比j更优, j同样不会是最优解, 结论成立. 

    证毕.

    所以如果把所有g[i, j] < g[j, k]的情况中(后面斜率比前面斜率小的情况)的j都去掉, 那么我们就得到相邻两个元素的斜率递增的状况

    如下图

    下面来说明怎么维护这个解集:

    用双端队列维护这个解集, 每次从后面加入元素时, 按照说明2的方式去掉多余元素, 使得相邻元素之间构成的斜率保持单调

    每次从前面找答案, 由于相邻元素的斜率单调递增, 所以最后一个斜率小于2*sum[i]的位置就是最优解, 因为这个位置之前相邻元素的斜率都小于2*sum[i],

    表示后面的比前面更优, 之后相邻元素的斜率都大于等于2*sum[i], 表示前面的比后面更优, 所以这个点是极值点

    又因为sum[i]也具有单调性, 所以下一个极值点的位置肯定大于等于当前极值点, 所以当前极值点之前的都可以从双端队列中移出

    ps:所有说明中, k < j < i

    例题1:HDU - 3507

    思路:维护递增斜率g[i, j] = (dp[i] - dp[j] + sum[i]*sum[i] - sum[j]*sum[j]) / (sum[i] - sum[j]) 

    代码:

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<bits/stdc++.h>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head 
    
    const int N = 5e5 + 10;
    int a[N], n, m;
    LL sum[N], dp[N];
    bool g(int k, int j, LL C) {
        return (dp[j]-dp[k]+sum[j]*sum[j]-sum[k]*sum[k]) <= C*(sum[j]-sum[k]);
    }
    bool gg(int k, int j, int i) {
        return (dp[i]-dp[j]+sum[i]*sum[i]-sum[j]*sum[j])*(sum[j]-sum[k]) <= (dp[j]-dp[k]+sum[j]*sum[j]-sum[k]*sum[k])*(sum[i]-sum[j]);
    }
    deque<int> q;
    // Example 1 (HDU 3507). For every test case on stdin, reads n and the
    // per-segment constant m followed by n values, then prints dp[n], where
    // dp[i] = min over j < i of dp[j] + (sum[i] - sum[j])^2 + m.
    int main() {
        // Stop on EOF and also on malformed input (the old ~scanf test only
        // stopped on EOF and would spin forever on a parse failure).
        while (scanf("%d %d", &n, &m) == 2) {
            for (int i = 1; i <= n; ++i) {
                scanf("%d", &a[i]);
                sum[i] = sum[i-1] + a[i];
            }
            q.clear();
            q.push_back(0);  // index 0 is the initial candidate; dp[0] stays 0
            for (int i = 1; i <= n; ++i) {
                // Discard the front while the second candidate is at least as
                // good for the query slope 2*sum[i].
                while (q.size() >= 2 && g(q[0], q[1], 2*sum[i])) q.pop_front();
                const int best = q.front();
                dp[i] = dp[best] + (sum[i]-sum[best])*(sum[i]-sum[best]) + m;
                // Discard the back while appending i would break the increasing
                // slope order.
                while (q.size() >= 2 && gg(q[q.size()-2], q.back(), i)) q.pop_back();
                q.push_back(i);
            }
            printf("%lld\n", dp[n]);
        }
        return 0;
    }
    View Code

    例题2:HDU - 1300

    思路:维护递增斜率g[i, j] = (dp[i] - dp[j]) / (sum[i] - sum[j]) 

    代码:

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<bits/stdc++.h>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head 
    
    const int N = 100 + 10;
    int a[N], p[N], n, m, T;
    LL sum[N], dp[N];
    bool g(int k, int j, LL C) {
        return (dp[j]-dp[k]) <= C*(sum[j]-sum[k]);
    }
    bool gg(int k, int j, int i) {
        return (dp[i]-dp[j])*(sum[j]-sum[k]) <= (dp[j]-dp[k])*(sum[i]-sum[j]);
    }
    deque<int> q;
    // Example 2 (HDU 1300). Reads T test cases; each has n pairs (a[i], p[i]).
    // Prints dp[n], where
    // dp[i] = min over j < i of dp[j] + (sum[i] - sum[j] + 10) * p[i].
    int main() {
        scanf("%d", &T);
        while (T--) {
            scanf("%d", &n);
            for (int i = 1; i <= n; ++i) {
                scanf("%d %d", &a[i], &p[i]);
                sum[i] = sum[i-1] + a[i];
            }
            // Replace every p[i] by the minimum of p[i..n].
            for (int i = n-1; i >= 1; --i) p[i] = min(p[i], p[i+1]);
            q.clear();
            q.push_back(0);  // index 0 is the initial candidate; dp[0] stays 0
            for (int i = 1; i <= n; ++i) {
                // Discard the front while the second candidate is at least as
                // good for the query slope p[i].
                while (q.size() >= 2 && g(q[0], q[1], p[i])) q.pop_front();
                const int best = q.front();
                dp[i] = dp[best] + (sum[i]-sum[best]+10)*p[i];
                // Discard the back while appending i would break the increasing
                // slope order.
                while (q.size() >= 2 && gg(q[q.size()-2], q.back(), i)) q.pop_back();
                q.push_back(i);
            }
            printf("%lld\n", dp[n]);
        }
        return 0;
    }
    View Code

    例题3:HDU - 2993

    思路:论文题,维护递增的斜率,居然卡读入,没意思

    代码:

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<bits/stdc++.h>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head 
    
    const int N = 1e5 + 10;
    int n, k, a[N], q[N], head, tail;
    double sum[N];
    const int BUF = 25000000;
    char Buf[BUF],*buf=Buf;
    // Parses the next unsigned decimal number at the global cursor `buf` into
    // `a`: first skips every byte whose value is below '0', then consumes every
    // following byte whose value is above '/' as a digit. No bounds checking is
    // done here.
    inline void read(int &a)
    {
        a = 0;
        while (*buf < 48) ++buf;
        while (*buf > 47) {
            a = a * 10 + *buf - 48;
            ++buf;
        }
    }
    // Example 3 (HDU 2993). Slurps all of stdin into Buf, then for every test
    // case (n, k, n values) prints, with two decimals, the largest average
    // (sum[i] - sum[x]) / (i - x) found over segments of length at least k.
    int main() {
        // fread returns size_t; keep it unsigned instead of narrowing to int.
        const size_t tot = fread(Buf, 1, BUF, stdin);
        const char* const end = Buf + tot;
        while (true) {
            // Skip separators before looking for the next case. The previous
            // test (buf-Buf+1 >= tot) only tolerated a single trailing byte, so
            // input ending in "\r\n" or a blank line sent read() past the data.
            while (buf < end && *buf < 48) ++buf;
            if (buf >= end) break;
            read(n), read(k);
            for (int i = 1; i <= n; ++i) read(a[i]), sum[i] = sum[i-1]+a[i];
            head = tail = 0;
            q[tail++] = 0;
            double ans = 0;
            for (int i = k; i <= n; ++i) {
                // Advance the front while the second candidate gives a strictly
                // larger average up to i than the first one.
                while (head+1 < tail) {
                    const int lo = q[head];
                    const int hi = q[head+1];
                    if ((sum[i]-sum[lo])*(i-hi) < (sum[i]-sum[hi])*(i-lo)) ++head;
                    else break;
                }
                int x = q[head];
                ans = max(ans, (sum[i]-sum[x])/(i-x));
                // Index i-k+1 becomes a legal left end for the next i.
                x = i-k+1;
                // Drop back candidates that break the hull order once x is added.
                while (head+1 < tail) {
                    const int hi = q[tail-1];
                    const int lo = q[tail-2];
                    if ((sum[x]-sum[hi])*(x-lo) < (sum[x]-sum[lo])*(x-hi)) --tail;
                    else break;
                }
                q[tail++] = x;
            }
            printf("%.2f\n", ans);
        }
        return 0;
    }
    View Code

    例题4:UVALive - 5097

    思路:去重后发现按宽度排序后,高度递减

    那么维护递增斜率:g[j, k] = (dp[j] - dp[k]) / (h[k] - h[j])

    代码:

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<bits/stdc++.h>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head 
    
    const int N = 5e4 + 10;
    pii a[N];
    vector<pii> vc;
    int n, k, h[N], w[N];
    LL dp[105][N];
    deque<int> q[105];
    bool g(int id, int k, int j, LL C) {
        return (dp[id][j]-dp[id][k]) <= C*(h[k+1]-h[j+1]);
    }
    bool gg(int id, int k, int j, int i) {
        return (dp[id][i]-dp[id][j])*(h[k+1]-h[j+1]) <= (dp[id][j]-dp[id][k])*(h[j+1]-h[i+1]);
    }
    int main() {
        while(~scanf("%d %d", &n, &k)) {
            for (int i = 1; i <= n; ++i) scanf("%d %d", &a[i].fi, &a[i].se);
            sort(a+1, a+1+n);
            vc.clear();
            for (int i = n; i >= 1; --i) if(i == n || a[i].se > vc.back().se) vc.pb(a[i]);
            reverse(vc.begin(), vc.end());
            n = vc.size();
            for (int i = 0; i < n; ++i) w[i+1] = vc[i].fi, h[i+1] = vc[i].se;
            for (int i = 0; i <= k; ++i) while(!q[i].empty()) q[i].pop_back();
            q[0].push_back(0);
            for (int i = 0; i <= k; ++i) for (int j = 0; j <= n; ++j) dp[i][j] = 0x3f3f3f3f3f3f3f3f;
            dp[0][0] = 0;
            for (int i = 1; i <= n; ++i) {
                for (int j = 0; j < k; ++j) {
                    while(q[j].size() >= 2) {
                        int a = q[j].front();
                        q[j].pop_front();
                        int b = q[j].front();
                        if(g(j, a, b, w[i])) ;
                        else {
                            q[j].push_front(a);
                            break;
                        }
                    }
                    int x = q[j].front();
                    dp[j+1][i] = min(dp[j+1][i], dp[j][x] + w[i]*1LL*h[x+1]);
                    while(q[j].size() >= 2) {
                        int b = q[j].back();
                        q[j].pop_back();
                        int a = q[j].back();
                        if(gg(j, a, b, i)) ;
                        else {
                            q[j].push_back(b);
                            break;
                        }
                    }
                    q[j].push_back(i);
                }
            }
            LL ans = 1LL<<60;
            for (int i = 1; i <= k; ++i) ans = min(ans, dp[i][n]);
            printf("%lld
    ", ans);
        }
        return 0;
    }
    View Code

    例题5:HDU - 3045

    思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k) / (a[j+1]-a[k+1])

    代码:

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<bits/stdc++.h>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head 
    
    const int N = 4e5 + 5;
    int n, k;
    LL a[N], sum[N], dp[N];
    bool g(int k, int j, LL C) {
        return dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k <= C*(a[j+1]-a[k+1]);
    }
    bool gg(int k, int j, int i) {
        return (dp[i]-dp[j]+sum[j]-sum[i]+a[i+1]*i-a[j+1]*j)*(a[j+1]-a[k+1]) <= (dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k)*(a[i+1]-a[j+1]);
    }
    deque<int> q;
    // Example 5 (HDU 3045). For every test case reads n values and a minimum
    // group size k, sorts the values and prints dp[n], where
    // dp[i] = min over allowed j of dp[j] + sum[i] - sum[j] - a[j+1] * (i - j).
    int main() {
        // Stop on EOF and also on malformed input (the old ~scanf test only
        // stopped on EOF).
        while (scanf("%d %d", &n, &k) == 2) {
            for (int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
            sort(a+1, a+1+n);
            for (int i = 1; i <= n; ++i) sum[i] = sum[i-1]+a[i];
            q.clear();
            dp[0] = 0;
            q.push_back(0);
            for (int i = k; i <= n; ++i) {
                // Discard the front while the second candidate is at least as
                // good for the query slope i.
                while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
                const int j = q.front();
                dp[i] = dp[j]+sum[i]-sum[j]-a[j+1]*1LL*(i-j);
                // Index i-k+1 becomes a legal split point for i+1, but only
                // indices >= k have a dp value computed by this loop.
                const int next = i-k+1;
                if (next >= k) {
                    while (q.size() >= 2 && gg(q[q.size()-2], q.back(), next)) q.pop_back();
                    q.push_back(next);
                }
            }
            printf("%lld\n", dp[n]);
        }
        return 0;
    }
    View Code

    例题6:POJ - 1180

    思路:要单独算s的影响,因为有s的存在时间就不好算前缀和了,对于每次新的开始s的影响是s*suf[i]

    那么就是维护递增斜率:g[j, k] = (dp[j]-dp[k]+s*(suf[j+1]-suf[k+1])) / (sum[j] - sum[k])

    代码:

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<iostream>
    #include<cstdio>
    #include<cmath>
    #include<algorithm>
    #include<deque>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head 
    
    const int N = 1e4 + 5;
    int T[N], F[N], n, s;
    LL sum[N], suf[N], dp[N];
    bool g(int k, int j, LL C) {
        return dp[j]-dp[k]+s*(suf[j+1]-suf[k+1]) <= C*(sum[j]-sum[k]);
    }
    bool gg(int k, int j, int i) {
        return (dp[i]-dp[j]+s*(suf[i+1]-suf[j+1]))*(sum[j]-sum[k]) <= (dp[j]-dp[k]+s*(suf[j+1]-suf[k+1]))*(sum[i]-sum[j]);
    }
    deque<int> q;
    // Example 6 (POJ 1180). Reads n, the setup constant s and n pairs
    // (T[i], F[i]); prints dp[n], where
    // dp[i] = min over j < i of dp[j] + T[i]*(sum[i] - sum[j]) + s*suf[j+1],
    // with T turned into prefix sums, sum = prefix sums of F and suf = suffix
    // sums of F.
    int main() {
        scanf("%d", &n);
        scanf("%d", &s);
        for (int i = 1; i <= n; ++i) scanf("%d %d", &T[i], &F[i]);
        for (int i = 1; i <= n; ++i) {
            sum[i] = sum[i-1] + F[i];
            T[i] += T[i-1];
        }
        for (int i = n; i >= 1; --i) suf[i] = suf[i+1] + F[i];
        q.push_back(0);  // index 0 is the initial candidate; dp[0] stays 0
        for (int i = 1; i <= n; ++i) {
            // Discard the front while the second candidate is at least as good
            // for the query slope T[i].
            while (q.size() >= 2 && g(q[0], q[1], T[i])) q.pop_front();
            const int j = q.front();
            dp[i] = dp[j] + T[i]*(sum[i]-sum[j])+s*suf[j+1];
            // Discard the back while appending i would break the increasing
            // slope order.
            while (q.size() >= 2 && gg(q[q.size()-2], q.back(), i)) q.pop_back();
            q.push_back(i);
        }
        printf("%lld\n", dp[n]);
        return 0;
    }
    View Code

    例题7:POJ - 2018

    思路:同HDU-2993

    代码:

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<iostream>
    #include<cstdio>
    #include<cmath>
    #include<algorithm>
    #include<deque>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head 
    
    const int N = 1e5 + 10;
    int n, f, a[N];
    LL sum[N];
    deque<int> q;
    bool g(int k, int j, int i) {
        return (sum[j]-sum[k])*(i-j) <= (sum[i]-sum[j])*(j-k);
    }
    // Example 7 (POJ 2018). Reads n, a minimum length f and n values; prints the
    // largest value of (sum[i] - sum[x]) * 1000 / (i - x) (integer division)
    // found over segments of length at least f.
    int main() {
        scanf("%d %d", &n, &f);
        for (int i = 1; i <= n; ++i) {
            scanf("%d", &a[i]);
            sum[i] = sum[i-1] + a[i];
        }
        q.push_back(0);
        LL ans = 0;
        for (int i = f; i <= n; ++i) {
            // Discard the front while the second candidate gives an average up
            // to i that is at least as large.
            while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
            int x = q.front();
            ans = max(ans, (sum[i]-sum[x])*1000/(i-x));
            // Index i+1-f becomes a legal left end for the next i.
            x = i+1-f;
            // Discard the back while it breaks the hull order once x is added.
            while (q.size() >= 2 && !g(q[q.size()-2], q.back(), x)) q.pop_back();
            q.push_back(x);
        }
        printf("%lld\n", ans);
        return 0;
    }
    View Code

    例题8:POJ - 3709

    思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]+a[j+1]*j-a[k+1]*k) / (a[j+1]-a[k+1])

    代码:

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<iostream>
    #include<cstdio>
    #include<cmath>
    #include<algorithm>
    #include<deque>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head 
    
    const int N = 5e5 + 10;
    int a[N], n, k, T;
    LL sum[N], dp[N];
    // Slope denominator for candidates k and j: a[j+1] - a[k+1].
    LL dw(int k, int j) {
        const LL diff = a[j+1] - a[k+1];
        return diff;
    }
    // Slope numerator for candidates k and j:
    // dp[j] - dp[k] + sum[k] - sum[j] + a[j+1]*j - a[k+1]*k, evaluated in 64 bits.
    LL up(int k, int j) {
        const LL weightedJ = a[j+1] * 1LL * j;
        const LL weightedK = a[k+1] * 1LL * k;
        return dp[j] - dp[k] + sum[k] - sum[j] + weightedJ - weightedK;
    }
    // Front-of-queue test: returns 1 when up(k, j) / dw(k, j) <= C (written
    // without division), otherwise 0.
    LL g(int k, int j, LL C) {
        const LL numerator = up(k, j);
        const LL denominator = dw(k, j);
        return numerator <= C * denominator ? 1 : 0;
    }
    // Back-of-queue test: returns 1 when slope(j, i) <= slope(k, j), compared by
    // cross-multiplying the up()/dw() pairs, otherwise 0.
    LL gg(int k, int j, int i) {
        const LL lhs = up(j, i) * dw(k, j);
        const LL rhs = up(k, j) * dw(j, i);
        return lhs <= rhs ? 1 : 0;
    }
    deque<int> q;
    // Example 8 (POJ 3709). Reads T test cases, each with n, a minimum group
    // size k and n values; prints dp[n], where
    // dp[i] = min over allowed x of dp[x] + sum[i] - sum[x] - a[x+1] * (i - x).
    int main() {
        scanf("%d", &T);
        while (T--) {
            scanf("%d %d", &n, &k);
            for (int i = 1; i <= n; ++i) {
                scanf("%d", &a[i]);
                sum[i] = sum[i-1] + a[i];
            }
            q.clear();
            q.push_back(0);  // index 0 is the initial candidate; dp[0] stays 0
            for (int i = k; i <= n; ++i) {
                // Discard the front while the second candidate is at least as
                // good for the query slope i.
                while (q.size() >= 2 && g(q[0], q[1], i)) q.pop_front();
                int x = q.front();
                dp[i] = dp[x]+sum[i]-sum[x]-a[x+1]*1LL*(i-x);
                // Index i-k+1 becomes a legal split point for i+1, but only
                // indices >= k have a dp value computed by this loop.
                x = i-k+1;
                if (x >= k) {
                    while (q.size() >= 2 && gg(q[q.size()-2], q.back(), x)) q.pop_back();
                    q.push_back(x);
                }
            }
            printf("%lld\n", dp[n]);
        }
        return 0;
    }
    View Code

    例题9:UVA - 12594

    思路:维护递增斜率:g[j, k] = (dp[j]-dp[k]+sum[k]-sum[j]-k*s[k]+j*s[j]) / (j-k),其中sum[i] = ∑_{t=1..i} (t-1-pos[t])*pos[t], s[i] = ∑_{t=1..i} pos[t], pos[t]为第t个字符在给定字母表中的下标(从0开始)

    #pragma GCC optimize(2)
    #pragma GCC optimize(3)
    #pragma GCC optimize(4)
    #include<bits/stdc++.h>
    using namespace std;
    #define y1 y11
    #define fi first
    #define se second
    #define pi acos(-1.0)
    #define LL long long
    //#define mp make_pair
    #define pb emplace_back
    #define ls rt<<1, l, m
    #define rs rt<<1|1, m+1, r
    #define ULL unsigned LL
    #define pll pair<LL, LL>
    #define pli pair<LL, int>
    #define pii pair<int, int>
    #define piii pair<pii, int>
    #define pdd pair<double, double>
    #define mem(a, b) memset(a, b, sizeof(a))
    // Debug helper: writes "<expression text> = <value>" and a newline to stderr.
    // The newline is spelled as an escape so the macro stays on one line.
    #define debug(x) std::cerr << #x << " = " << x << "\n";
    #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    //head
    
    const int N = 2e4 + 10, M = 505;
    const LL INF = 0x3f3f3f3f3f3f3f3f;
    int T, n, k, pos[26];
    LL sum[N], s[N], dp[M][N];
    char nm[N], pn[N];
    deque<int> q[M];
    // Slope numerator in layer `id` for candidates k and j:
    // dp[id][j] - dp[id][k] + sum[k] - sum[j] - k*s[k] + j*s[j].
    LL up(int id, int k, int j) {
        const LL weightedK = k * s[k];
        const LL weightedJ = j * s[j];
        return dp[id][j] - dp[id][k] + sum[k] - sum[j] - weightedK + weightedJ;
    }
    // Slope denominator for candidates k and j: the index distance j - k.
    LL dw(int k, int j) {
        const LL distance = j - k;
        return distance;
    }
    bool g(int id, int k, int j, LL C) {
        return up(id, k, j) <= C*dw(k, j);
    }
    bool gg(int id, int k, int j, int i) {
        return up(id, j, i)*dw(k, j) <= up(id, k, j)*dw(j, i);
    }
    int main() {
        scanf("%d", &T);
        for(int cs = 1; cs <= T; ++cs) {
            scanf("%s %d", pn, &k);
            scanf("%s", nm+1);
            n = strlen(nm+1);
            for (int i = 0; i < 26; ++i) pos[pn[i]-'a'] = i;
            for (int i = 1; i <= n; ++i) s[i] = s[i-1]+pos[nm[i]-'a'];
            for (int i = 1; i <= n; ++i) sum[i] = sum[i-1]+(i-1-pos[nm[i]-'a'])*1LL*pos[nm[i]-'a'];
            for (int i = 0; i <= k; ++i) while(!q[i].empty()) q[i].pop_back();
            dp[0][0] = 0;
            q[0].push_back(0);
            for (int i = 1; i <= n; ++i) {
                for (int j = 0; j < k; ++j) {
                    while(q[j].size() >= 2) {
                        int a = q[j].front();
                        q[j].pop_front();
                        int b = q[j].front();
                        if(g(j, a, b, s[i])) ;
                        else {
                            q[j].push_front(a);
                            break;
                        }
                    }
                    int x = q[j].front();
                    dp[j+1][i] = dp[j][x]+sum[i]-sum[x]-x*(s[i]-s[x]);
                }
                for (int j = 1; j <= k; ++j) {
                    while(q[j].size() >= 2) {
                        int b = q[j].back();
                        q[j].pop_back();
                        int a = q[j].back();
                        if(gg(j, a, b, i)) ;
                        else {
                            q[j].push_back(b);
                            break;
                        }
                    }
                    q[j].push_back(i);
                }
            }
            printf("Case %d: %lld
    ", cs, dp[k][n]);
        }
        return 0;
    }
    View Code

     

  • 相关阅读:
    线程池:第一章:线程池的底层原理
    实战:第一章:防止其他人通过用户的url访问用户私人数据
    java程序报错:Unable to open debugger port (127.0.0.1:63959): java.net.SocketException "socket closed",编译过来就是无法打开调试器端口,套接字已关闭
    面试:第十三章:中高级程序员面试
    队列:第一章:阻塞队列
    我的分享:第三章:SpringCould五大组件
    Linux系统:第十章:服务器环境搭建
    深入理解JUC:第六章:Semaphore信号灯
    编写高质量JS代码的68个有效方法(十三)
    前端构建之gulp与常用插件
  • 原文地址:https://www.cnblogs.com/widsom/p/9323394.html
Copyright © 2011-2022 走看看