zoukankan      html  css  js  c++  java
  • 高维前缀和总结(sosdp)

    前言

    今天中午不知怎么的对这个东西产生了兴趣,感觉很神奇,结果花了一个中午多的时间来看QAQ
    下面说下自己的理解。

    高维前缀和一般解决这类问题:

    对于所有的(i,0leq ileq 2^n-1),求解(sum_{jsubset i}a_j)

    显然,这类问题可以直接枚举子集求解,但复杂度为(O(3^n))。如果我们施展高维前缀和的话,复杂度可以到(O(ncdot 2^n))

    说起来很高级,其实代码就三行:

    for(int j = 0; j < n; j++) 
        for(int i = 0; i < 1 << n; i++)
            if(i >> j & 1) f[i] += f[i ^ (1 << j)];
    

    相信大家一开始学的时候就感觉很神奇,这是什么东西,这跟前缀和有什么关系?
    好吧,其实看到后面就知道了。

    正文

    二维前缀和

    一维前缀和就不说了,一般我们求二维前缀和时是直接容斥来求的:

    [sum_{i,j}=sum_{i-1,j}+sum_{i,j-1}-sum_{i-1,j-1}+a_{i,j} ]

    但还有一种求法,就是一维一维来求,也可以得到二维前缀和:

    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            a[i][j] += a[i - 1][j];
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            a[i][j] += a[i][j - 1];
    

    模拟一下就很清晰了。

    三维前缀和

    同二位前缀和,我们也可以对每一维分别来求:

    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            for(int k = 1; k <= n; k++) 
                a[i][j][k] += a[i - 1][j][k];
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            for(int k = 1; k <= n; k++)
                a[i][j][k] += a[i][j - 1][k];
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            for(int k = 1; k <= n; k++)
                a[i][j][k] += a[i][j][k - 1];
    

    高维前缀和

    接下来就步入正题啦。
    求解高维前缀和的核心思想也就是一维一维来处理,可以类比二维前缀和的求法稍微模拟一下。
    具体来说代码中的(f[i] = f[i] + f[i xor (1 << j)]),因为我们是正序枚举,所以(i xor (1 << j))在当前层,而(i)还在上层,所以我们将两个合并一下就能求出当前层的前缀和了QAQ。
    然后...就完了,好像没什么好说的。

    应用

    • 子集

    那这跟子集有啥关系?在二进制表示中,发现当(isubset j)时,其实这存在一个偏序关系,对于每一位都是这样。而我们求出的前缀和就是满足这个偏序关系的。
    回到开始那个问题,初始化(f[i]=a_i),直接求高维前缀和,那么最终得到的(f)就是答案数组了。

    • 超集

    理解了子集过后,我们将二进制中的每一个(1)当作(0)对待,(0)当作(1)对待求出来的就是超集了~相当于从另一个角出发来求前缀和。
    求超集代码如下:

    for(int j = 0; j < n; j++) 
        for(int i = 0; i < 1 << n; i++)
            if(!(i >> j & 1)) f[i] += f[i ^ (1 << j)];
    

    似乎(FMT)(快速莫比乌斯变换)就是借助高维前缀和这个东西来实现的。
    虽然只有三行代码,但很神奇QAQ

    upd:
    这个东西其实和(sos dp)是一个东西,但感觉用(dp)的思想去理解要稍微好一些,就再说一下(dp)的想法。

    • 子集

    我们还是来求子集,定义(dp_{i,mask})为处理了状态为(mask),二进制最后(i)位的子集信息时的和。
    那么我们枚举(i+1)位时,若当前(mask)这一位为(1),那么就从(dp_{i,mask},dp_{i,mask-(1<<i)})转移过来,分别代表有当前这一位时的子集或者没这一位时的子集,合并一下即可;若当前这位不为(1),就从(dp_{i,mask})转移过来。
    最后在代码中我们一般习惯滚动掉一维。

    • 超集

    类似地,定义(dp_{i,mask})为当前状态为(mask),处理了后(i)位的超集信息时的和。
    然后枚举第(i+1)位,若当前这一位为(0),就从(dp_{i,mask},dp_{i,mask+(1<<(i+1))})转移;若当前这位为(1),就直接从(dp_{i,mask})转移过来。

    例题

    arc 100E

    题意:
    给出(2^n)个数:(a_0,a_1,cdots,a_{2^n-1})
    之后对于(1leq kleq 2^n-1),求出:(a_i+a_j)的最大值,同时(i or jleq k)

    思路:
    挺奇妙的一个题,需要将问题转换。

    • 发现我们可以对每个(k),求出最大的(a_i+a_j)并且满足(i or j=k),最后答案就为一个前缀最大值。
    • 但这种形式也不好处理,我们可以将问题进一步转化为(i or jsubset k)。那么我们就将问题转化为了子集问题。
    • 所以接下来就对于每个(k),求出其所有子集的最大值和次大值就行了。
    • 直接枚举子集复杂度显然不能忍受,其实直接上高位前缀和搞一下就行~

    注意一下细节,集合中一开始有一个数。
    代码如下:

    #include <bits/stdc++.h>
    #define MP make_pair
    #define fi first
    #define se second
    #define sz(x) (int)(x).size()
    #define INF 0x3f3f3f3f3f
    //#define Local
    using namespace std;
    typedef long long ll;
    typedef pair<int, int> pii;
    const int N = 20;
    
    int n;
    pii a[1 << N];
    
    pii merge(pii A, pii B) {
        if(A.fi < B.fi) swap(A, B);
        pii ans = A;
        if(B.fi > ans.se) ans.se = B.fi;
        return ans;
    }
    
    void run() {
        for(int i = 0; i < 1 << n; i++) {
            int x; cin >> x;
            a[i] = MP(x, -INF);
        }
        for(int j = 0; j < n; j++) {
            for(int i = 0; i < 1 << n; i++) {
                if(i >> j & 1) a[i] = merge(a[i], a[i ^ (1 << j)]);
            }
        }
        int ans = 0;
        for(int i = 1; i < 1 << n; i++) {
            ans = max(ans, a[i].fi + a[i].se);
            cout << ans << '
    ';
        }
    }
    
    int main() {
        ios::sync_with_stdio(false);
        cin.tie(0); cout.tie(0);
        cout << fixed << setprecision(20);
    #ifdef Local
        freopen("../input.in", "r", stdin);
        freopen("../output.out", "w", stdout);
    #endif
        while(cin >> n) run();
        return 0;
    }
    
    

    cf1208F
    题意:
    给出序列(a_{1,2cdots,n},nleq 10^6)
    现在要找最大的(a_i|(a_j& a_k)),其中((i,j,k))满足(i<j<k)

    思路:

    • 显然我们可以枚举(a_i),那么问题就转换为如何快速找(a_j& a_k)
    • 因为最后要使得结果最大,我们二进制从高到底枚举时肯定是贪心来考虑的:即如果有两个数他们的与在这一位为(1),那么最后的答案中一定有这一位。
    • 那么我们逐位考虑,并且考虑是否有两个在右边的数他们"与"的结果为当前答案的超集即可,有的话答案直接加上这一位。
    • 那么可以用(sos dp)处理超集的信息,维护在最右端的两个位置,之后贪心来处理即可。

    代码如下:

    /*
     * Author:  heyuhhh
     * Created Time:  2020/2/27 10:51:39
     */
    #include <iostream>
    #include <algorithm>
    #include <cstring>
    #include <vector>
    #include <cmath>
    #include <set>
    #include <map>
    #include <queue>
    #include <iomanip>
    #define MP make_pair
    #define fi first
    #define se second
    #define sz(x) (int)(x).size()
    #define all(x) (x).begin(), (x).end()
    #define INF 0x3f3f3f3f
    #define Local
    #ifdef Local
      #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
      void err() { std::cout << '
    '; }
      template<typename T, typename...Args>
      void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
    #else
      #define dbg(...)
    #endif
    void pt() {std::cout << '
    '; }
    template<typename T, typename...Args>
    void pt(T a, Args...args) {std::cout << a << ' '; pt(args...); }
    using namespace std;
    typedef long long ll;
    typedef pair<int, int> pii;
    //head
    const int N = 2e6 + 5;
     
    int n;
    int a[N];
    pii dp[N];
     
    void add(int x, int id) {
        if(dp[x].fi == -1) {
            dp[x].fi = id;   
        } else if(dp[x].se == -1) {
            if(dp[x].fi == id) return;
            dp[x].se = id;   
            if(dp[x].fi < dp[x].se) swap(dp[x].fi, dp[x].se);
        } else if(dp[x].fi < id) {
            dp[x].se = dp[x].fi;
            dp[x].fi = id;   
        } else if(dp[x].se < id) {
            if(dp[x].fi == id) return;
            dp[x].se = id;
        }
    }
     
    void merge(int x1, int x2) {
        add(x1, dp[x2].fi);
        add(x1, dp[x2].se);
    }
     
    void run() {
        memset(dp, -1, sizeof(dp));
        cin >> n;
        for(int i = 1; i <= n; i++) {
            cin >> a[i];
            add(a[i], i);
        }
        for(int i = 0; i < 21; i++) {
            for(int j = 0; j < N; j++) if(j >> i & 1) {
                merge(j ^ (1 << i), j);
            }
        }
        int ans = 0;
        for(int i = 1; i <= n - 2; i++) {
            int lim = (1 << 21) - 1;
            int cur = a[i] ^ lim, res = 0;
            for(int j = 20; j >= 0; j--) if(cur >> j & 1) {
                if(dp[res ^ (1 << j)].se > i) {
                    res ^= (1 << j);   
                }
            }
            ans = max(ans, res | a[i]);
        }
        cout << ans << '
    ';
    }
     
    int main() {
        ios::sync_with_stdio(false);
        cin.tie(0); cout.tie(0);
        cout << fixed << setprecision(20);
        run();
        return 0;
    }
    
  • 相关阅读:
    Leetcode 191.位1的个数 By Python
    反向传播的推导
    Leetcode 268.缺失数字 By Python
    Leetcode 326.3的幂 By Python
    Leetcode 28.实现strStr() By Python
    Leetcode 7.反转整数 By Python
    Leetcode 125.验证回文串 By Python
    Leetcode 1.两数之和 By Python
    Hdoj 1008.Elevator 题解
    TZOJ 车辆拥挤相互往里走
  • 原文地址:https://www.cnblogs.com/heyuhhh/p/11585358.html
Copyright © 2011-2022 走看看