zoukankan      html  css  js  c++  java
  • HDU5845 trie树优化dp

    http://acm.hdu.edu.cn/showproblem.php?pid=5845

    题意:给定序列,问最多可以分成多少段序列使得每段序列不超过L且异或和不超过X

    首先对于区间异或和,很容易想到前缀异或和去优化使其可以在O(1)时间内求出区间异或和,然后我们就可以写出一个n²暴力

    #include <map>
    #include <set>
    #include <ctime>
    #include <cmath>
    #include <queue>
    #include <stack>
    #include <vector>
    #include <string>
    #include <bitset>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <sstream>
    #include <iostream>
    #include <algorithm>
    #include <functional>
    using namespace std;
    #define For(i, x, y) for(int i=x;i<=y;i++)  
    #define _For(i, x, y) for(int i=x;i>=y;i--)
    #define Mem(f, x) memset(f,x,sizeof(f))  
    #define Sca(x) scanf("%d", &x)
    #define Sca2(x,y) scanf("%d%d",&x,&y)
    #define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
    #define Scl(x) scanf("%lld",&x)  
    #define Pri(x) printf("%d
    ", x)
    #define Prl(x) printf("%lld
    ",x)  
    #define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
    #define LL long long
    #define ULL unsigned long long  
    #define mp make_pair
    #define PII pair<int,int>
    #define PIL pair<int,long long>
    #define PLL pair<long long,long long>
    #define pb push_back
    #define fi first
    #define se second 
    typedef vector<int> VI;
    int read(){int x = 0,f = 1;char c = getchar();while (c<'0' || c>'9'){if (c == '-') f = -1;c = getchar();}
    while (c >= '0'&&c <= '9'){x = x * 10 + c - '0';c = getchar();}return x*f;}
    const double PI = acos(-1.0);
    const double eps = 1e-9;
    const int maxn = 1e5 + 10;
    const int INF = 0x3f3f3f3f;
    const int mod = 268435456; 
    LL N,X,L,P,Q;
    LL a[maxn],dp[maxn];
    LL pre[maxn];
    LL sum(int i,int j){
        return pre[j] ^ pre[i - 1];
    }
    int main(){
        int T; Sca(T);
        while(T--){
            scanf("%lld%lld%lld",&N,&X,&L);
            scanf("%lld%lld%lld",&a[1],&P,&Q);
            for(int i = 2; i <= N ; i ++){
                a[i] = ((a[i - 1] * P) + Q) % mod;
            }
            pre[0] = 0;
            for(int i = 1; i <= N ; i ++) pre[i] = pre[i - 1] ^ a[i];
            for(int i = 0; i <= N ; i ++) dp[i] = 0;
            for(int i = 1; i <= N; i ++){
                for(int j = max(0LL,i - L); j < i ; j ++){
                    if(sum(j + 1,i) <= X) dp[i] = max(dp[i],dp[j] + 1);
                }
            }
            Prl(dp[N]);
        }
        return 0;
    }
    n²暴力

    我们可以发现对于pre相同的下标而言,dp的大小呈单调性,即i > j 且pre[i] = pre[j] 则dp[i] > dp[j],由于i,j之间异或和为0,显然dp[i] - dp[j] >= 1

    那么对于前面长度L的区间,我们可以考虑用字典树优化,用01字典树维护每个前缀和的dp最大值,由于满足单调性,对于字典树上的删除我们只需要维护每个节点出现的次数,因为只要字典树上还存在当前节点(出现次数不为0),就意味着当前最大值不会变(最大值永远越后面的越大)

    对于查询的时候就需要讨论,如果当前位X为0,说明查询的pre当前位上也是0,需要走当前位与他相同的路径,如果X为1,那么可以走与当前位相反的路径使得该位和X一样为1,或者走与其相同的路径使得该位为0,倘若走0的路径,那么直接取子树的最大值不用继续往下走,因为下面无论怎么走都一定比X小

    #include <map>
    #include <set>
    #include <ctime>
    #include <cmath>
    #include <queue>
    #include <stack>
    #include <vector>
    #include <string>
    #include <bitset>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>
    #include <sstream>
    #include <iostream>
    #include <algorithm>
    #include <functional>
    using namespace std;
    #define For(i, x, y) for(int i=x;i<=y;i++)  
    #define _For(i, x, y) for(int i=x;i>=y;i--)
    #define Mem(f, x) memset(f,x,sizeof(f))  
    #define Sca(x) scanf("%d", &x)
    #define Sca2(x,y) scanf("%d%d",&x,&y)
    #define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
    #define Scl(x) scanf("%lld",&x)  
    #define Pri(x) printf("%d
    ", x)
    #define Prl(x) printf("%lld
    ",x)  
    #define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
    #define LL long long
    #define ULL unsigned long long  
    #define mp make_pair
    #define PII pair<int,int>
    #define PIL pair<int,long long>
    #define PLL pair<long long,long long>
    #define pb push_back
    #define fi first
    #define se second 
    typedef vector<int> VI;
    int read(){int x = 0,f = 1;char c = getchar();while (c<'0' || c>'9'){if (c == '-') f = -1;c = getchar();}
    while (c >= '0'&&c <= '9'){x = x * 10 + c - '0';c = getchar();}return x*f;}
    const int maxn = 1e5 + 10;
    const int maxm = 5e6 + 10;
    const LL INF = 1e18;
    const LL mod = 268435456; 
    LL N,X,L,P,Q;
    LL a[maxn],dp[maxn],pre[maxn];
    int nxt[maxm][2],cnt,num[maxm];
    LL val[maxm];
    void insert(int j){
        LL x = pre[j],v = dp[j];
        int p = 1;
        for(int i = 32; i >= 0; i --){
            int id = (x >> i) & 1;
            if(!nxt[p][id]){
                 nxt[p][id] = ++cnt;
                 val[cnt] = -INF; num[cnt] = nxt[cnt][0] = nxt[cnt][1] = 0;
            }
            p = nxt[p][id];
            val[p] = max(val[p],v); num[p]++;
        }
    }
    void del(int p,int i,LL x){
        if(i == -1){if(!num[p]) val[p] = -INF;return;}
        int id = (x >> i) & 1;
        num[nxt[p][id]]--;
        del(nxt[p][id],i - 1,x);
        val[p] = val[nxt[p][id]];
        if(nxt[p][id ^ 1] && num[nxt[p][id ^ 1]] > 0) val[p] = max(val[nxt[p][0]],val[nxt[p][1]]);
    }
    LL query(LL x){
        int p = 1;
        LL ans = -INF;
        for(int i = 32; i >= 0 ; i --){
            int id = (x >> i) & 1;
            if((X >> i) & 1){
                if(nxt[p][id] && num[nxt[p][id]]){
                    ans = max(ans,val[nxt[p][id]]);
                }
                if(nxt[p][id ^ 1] && num[nxt[p][id ^ 1]]){
                    p = nxt[p][id ^ 1];
                }
            }else{
                if(!nxt[p][id] || !num[nxt[p][id]]) return ans;
                p = nxt[p][id];
            }
        }
        ans = max(ans,val[p]);
        return ans;
        return val[p];
    }
    int main(){
        int T; Sca(T); cnt = 1;
        while(T--){
            for(int i = 0 ; i <= cnt; i ++){val[i] = -INF; nxt[i][0] = nxt[i][1] = num[i] = 0;}
            scanf("%lld%lld%lld",&N,&X,&L); cnt = 1;
            scanf("%lld%lld%lld",&a[1],&P,&Q);
            for(int i = 2; i <= N ; i ++) a[i] = ((a[i - 1] * P) + Q) % mod;
            pre[0] = dp[0] = 0; insert(0);
            for(int i = 1; i <= N ; i ++) pre[i] = pre[i - 1] ^ a[i];
        //    For(i,1,N) cout << pre[i] << " ";
        //    cout << endl;
            for(int i = 1; i <= N; i ++){
                if(i - L - 1 >= 0 && dp[i - L - 1] >= 0) del(1,32,pre[i - L - 1]);
                dp[i] = query(pre[i]) + 1;
                if(dp[i] > 0) insert(i);
            }
            if(dp[N] < 0) dp[N] = 0;
            Prl(dp[N]);
        }
        return 0;
    }
  • 相关阅读:
    Python——数据结构——字典
    Python——print()函数
    Python数据结构——序列总结
    elasticsearch全文检索java
    elasticsearch单例模式连接 java
    【转载】信号中断 与 慢系统调用
    设计模式——状态模式(C++实现)
    设计模式——观察者模式(C++实现)
    C++调用C方法
    设计模式——外观模式(C++实现)
  • 原文地址:https://www.cnblogs.com/Hugh-Locke/p/11280544.html
Copyright © 2011-2022 走看看