zoukankan      html  css  js  c++  java
  • ZOJ 3494 BCD Code(AC自动机 + 数位DP)题解

    题意:每位十进制数都能转化为4位二进制数,比如9是1001,127是 000100100111,现在问你,在L到R(R <= $10^{200}$)范围内,有多少数字的二进制表达式不包含模式串。

    思路:显然这是一道很明显的数位DP + AC自动机的题目。但是你要是直接把数字转化为二进制,然后在Trie树上数位DP你会遇到一个问题,以为转化为二进制后,前导零变成了四位000,那么你在DP的时候还要考虑前4位是不是都是000那样就要重新跑Trie树,显然这样是很菜(不会)的。那么肯定是想办法要变成十进制跑Trie树。

    那我们就预处理出一个bcd[i][j]表示在Trie树上i节点走向数字j可不可行,这样就行了。

    代码:

    #include<set>
    #include<map>
    #include<queue>
    #include<cmath>
    #include<string>
    #include<cstdio>
    #include<vector>
    #include<cstring>
    #include <iostream>
    #include<algorithm>
    using namespace std;
    typedef long long ll;
    typedef unsigned long long ull;
    const int maxn = 2000 + 5;
    const int M = 50 + 5;
    const ull seed = 131;
    const int INF = 0x3f3f3f3f;
    const int MOD = 1000000009;
    int n, m;
    int bit[205], pos;
    ll dp[500][maxn];
    int bcd[maxn][10];
    struct Aho{
        struct state{
            int next[10];
            int fail, cnt;
        }node[maxn];
        int size;
        queue<int> q;
    
        void init(){
            size = 0;
            newtrie();
            while(!q.empty()) q.pop();
        }
    
        int newtrie(){
            memset(node[size].next, 0, sizeof(node[size].next));
            node[size].cnt = node[size].fail = 0;
            return size++;
        }
    
        void insert(char *s){
            int len = strlen(s);
            int now = 0;
            for(int i = 0; i < len; i++){
                int c = s[i] - '0';
                if(node[now].next[c] == 0){
                    node[now].next[c] = newtrie();
                }
                now = node[now].next[c];
            }
            node[now].cnt = 1;
    
        }
    
        void build(){
            node[0].fail = -1;
            q.push(0);
    
            while(!q.empty()){
                int u = q.front();
                q.pop();
                if(node[node[u].fail].cnt && u) node[u].cnt |= node[node[u].fail].cnt;
                for(int i = 0; i < 10; i++){
                    if(!node[u].next[i]){
                        if(u == 0)
                            node[u].next[i] = 0;
                        else
                            node[u].next[i] = node[node[u].fail].next[i];
                    }
                    else{
                        if(u == 0) node[node[u].next[i]].fail = 0;
                        else{
                            int v = node[u].fail;
                            while(v != -1){
                                if(node[v].next[i]){
                                    node[node[u].next[i]].fail = node[v].next[i];
                                    break;
                                }
                                v = node[v].fail;
                            }
                            if(v == -1) node[node[u].next[i]].fail = 0;
                        }
                        q.push(node[u].next[i]);
                    }
                }
            }
        }
    
        ll dfs(int pos, int st, bool Max, bool lead){
            if(pos == -1) return 1;
            if(!Max && !lead && dp[pos][st] != -1) return dp[pos][st];
            int top = Max? bit[pos] : 9;
            ll ans = 0;
            for(int i = 0; i <= top; i++){
                if(lead && i == 0 && pos != 0){
                    ans = (ans + dfs(pos - 1, 0, Max && i == top, lead && i == 0)) % MOD;
                    continue;
                }
                if(bcd[st][i] == -1) continue;
                ans = (ans + dfs(pos - 1, bcd[st][i], Max && i == top, lead && i == 0)) % MOD;
            }
            if(!Max && !lead) dp[pos][st] = ans;
            return ans;
        }
    
        ll solve(char *s){
            pos = 0;
            int len = strlen(s);
            for(int i = len - 1; i >= 0; i--){
                bit[pos++] = s[i] - '0';
            }
            return dfs(pos - 1, 0, true, true);
        }
    
        char num[10][5] = {"0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111", "1000", "1001"};
        void init_bcd(){
            memset(bcd, 0, sizeof(bcd));
            for(int i = 0; i < size; i++){
                for(int j = 0; j < 10; j++){
                    int v = i;
                    for(int k = 0; k < 4; k++){
                        v = node[v].next[num[j][k] - '0'];
                        if(node[v].cnt){
                            bcd[i][j] = -1;
                            break;
                        }
                    }
                    if(bcd[i][j] != -1) bcd[i][j] = v;
                }
            }
        }
    
    }ac;
    
    char s1[205], s2[205];
    int main(){
        int T;
        scanf("%d", &T);
        while(T--){
            memset(dp, -1, sizeof(dp));
            scanf("%d", &n);
            ac.init();
            for(int i = 0; i < n; i++){
                scanf("%s", s1);
                ac.insert(s1);
            }
            ac.build();
            ac.init_bcd();
    
            scanf("%s%s", s1, s2);
            int lens1 = strlen(s1);
            int pp = lens1 - 1;
            while(s1[pp] == '0'){
                s1[pp] = '9';
                pp--;
            }
            s1[pp]--;
            if(s1[0] == '0' && lens1 > 1){
                for(int i = 1; i < lens1; i++){
                    s1[i - 1] = s1[i];
                }
                s1[lens1 - 1] = '';
            }
    //        cout << s1 << endl;
            ll ans1 = ac.solve(s1);
            ll ans2 = ac.solve(s2);
            ll ans = ((ans2 - ans1) % MOD + MOD) % MOD;
            printf("%lld
    ", ans);
        }
        return 0;
    }
  • 相关阅读:
    Java多线程问题
    pattern-matching as an expression without a prior match -scala
    从Zero到Hero,OpenAI重磅发布深度强化学习资源
    What-does-git-remote-and-origin-mean
    flink source code
    如何生成ExecutionGraph及物理执行图
    rocketmq 源码
    Flink source task 源码分析
    flink 获取上传的Jar源码
    fileupload
  • 原文地址:https://www.cnblogs.com/KirinSB/p/11199789.html
Copyright © 2011-2022 走看看