zoukankan      html  css  js  c++  java
  • LOJ#2552. 「CTSC2018」假面(期望 背包)

    题意

    题目链接

    Sol

    多年以后,我终于把这题的暴力打出来了qwq 好感动啊。。

    刚开始的时候想的是:

    (f[i][j])表示第(i)轮, 第(j)个人血量的期望值

    转移的时候若要淦这个人,那么(f[i][j] = (f[i - 1][j] + 1) * p + (f[i - 1][j]) * (1 - p))

    然后发现自己傻逼了。。因为期望不能正着推。

    考虑直接推概率,设(t[k][i][j])表示第(k)轮,第(i)个人,血量为(j)的概率

    这玩意儿是可以转移的,就是判一下这次打中了没有

    第二问可以对每个点分别算答案,设(g[i][j])表示除必须活着的人外,前(i)个人中,有(j)个活着的概率,背包转移一下

    这样复杂度是(O(qn + n^3))

    显然第二问看起来非常暴力,

    标算的做法好像叫“退背包”,也就是从背包中删除一个元素

    先不考虑某个元素必须存活,推一遍得到(g[i][j])表示前(i)个人中,有(j)个存活的概率

    考虑转移的式子,设(ali[i])表示第(i)个人活着的概率

    (g[i][j] = g[i - 1][j - 1] * ali[i] + g[i - 1][j] * (1 - ali[i]))

    而我们要得到的实际上就是(g[i-1][j])这一项

    那么(g[i - 1][j] = frac{g[i][j] - g[i - 1][j - 1] * ali[i]}{1 - ali[i]})

    倒着推一遍即可,注意当(1 - ali[i] = 0)的时候需要特判,此时(g[i - 1][j] = g[i][j + 1])

    70分

    #include<bits/stdc++.h>
    //#define int long long 
    using namespace std;
    const int MAXN = 201, mod = 998244353;
    int f[2][MAXN], g[MAXN][MAXN], t[2][MAXN][MAXN];
    // f: expect 
    inline int read() {
        char c = getchar(); int x = 0, f = 1;
        while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
        while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
        return x * f;
    }
    int N, a[MAXN], Q, em[MAXN];
    int fp(int a, int p) {
        int base = 1;
        while(p) {
            if(p & 1) base = 1ll * base * a % mod;
            a = 1ll * a * a % mod; p >>= 1;
        }
        return base;
    }
    int inv(int a) {
        return fp(a, mod - 2);
    }
    int add(int x, int y) {
        if(x + y < 0) return x + y + mod;
        else return x + y >= mod ? x + y - mod : x + y;
    }
    int mul(int x, int y) {
        x = (x + mod) % mod; y = (y + mod) % mod;
        return 1ll * x * y % mod;
    }
    int solve(int id, int o, int N) {//这里dp的时候不能直接表示有j个活着,必须表示除i之外有j个活着。。 
        memset(g, 0, sizeof(g));
        g[0][0] = 1; 
        for(int i = 1; i <= N; i++) {
            for(int j = 0; j <= N; j++) {
                if(em[i] ^ id) {
                    g[i][j] = mul(g[i - 1][j], t[o][em[i]][0]);
                    if(j) g[i][j] = add(g[i][j], mul(g[i - 1][j - 1], 1 - t[o][em[i]][0]));
                }
                else g[i][j] = g[i - 1][j];
            }
        }
        int ans = 0;
        for(int i = 0; i < N; i++) 
            ans = add(ans, mul(mul(1 - t[o][id][0], g[N][i]), inv(i + 1)));
        return ans;
    }
    signed main() {
    //  freopen("a.in", "r", stdin);
    //  freopen("b.out", "w", stdout);
        N = read();
        for(int i = 1; i <= N; i++) a[i] = read(), t[0][i][a[i]] = 1, f[0][i] = a[i];
        Q = read();
        int o = 1;
        for(int i = 1; i <= Q; i++, o ^= 1) {
            int opt = read();
            memcpy(t[o], t[o ^ 1], sizeof(t[o]));
            if(opt == 0) {//
                int id = read(), u = read(), v = read(), p = 1ll * u * inv(v) % mod;
                t[o][id][0] = add(t[o][id][0], mul(p, t[o][id][1]));
                for(int j = 1; j <= a[id]; j++) t[o][id][j] = add(mul(p, t[o ^ 1][id][j + 1]), mul(1 - p, t[o ^ 1][id][j]));
            } else if(opt == 1) {
                int k = read(), cnt = 0;
                for(int i = 1; i <= k; i++) em[++cnt] = read();
                for(int i = 1; i <= k; i++) printf("%d ", solve(em[i], o, cnt)); puts("");
            }
        }
        for(int i = 1; i <= N; i++) {
            int ans = 0;
            for(int j = 1; j <= a[i]; j++) 
                ans = add(ans, mul(j, t[o ^ 1][i][j]));
            printf("%d ", ans);
            
        }
        return 0;
    }
    /*
    */
    

    100分

    #include<bits/stdc++.h>
    //#define int long long 
    using namespace std;
    const int MAXN = 201, mod = 998244353;
    int f[2][MAXN], g[MAXN][MAXN], t[2][MAXN][MAXN];
    // f: expect 
    inline int read() {
        char c = getchar(); int x = 0, f = 1;
        while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
        while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
        return x * f;
    }
    int N, a[MAXN], Q, em[MAXN], ans[MAXN], ali[MAXN], tp[MAXN], Inv[MAXN];
    int add(int x, int y) {
        if(x + y < 0) return x + y + mod;
        else return x + y >= mod ? x + y - mod : x + y;
    }
    int mul(int x, int y) {
        x = (x + mod) % mod; y = (y + mod) % mod;
        return 1ll * x * y % mod;
    }
    int fp(int a, int p) {
        int base = 1;
        while(p) {
            if(p & 1) base = 1ll * base * a % mod;
            a = 1ll * a * a % mod; p >>= 1;
        }
        return base;
    }
    int inv(int a) {
        a = add(a, mod);
        return fp(a, mod - 2);
    }
    
    void Pre(int o, int N) {
    //  memset(g, 0, sizeof(g));
        g[0][0] = 1; 
        for(int i = 1; i <= N; i++) {
            ali[i] = (1 - t[o][em[i]][0] + mod) % mod;//alive
            for(int j = 0; j <= i; j++) {
                g[i][j] = mul(g[i - 1][j], t[o][em[i]][0]);
                if(j) g[i][j] = add(g[i][j], mul(g[i - 1][j - 1], ali[i]));
            }
        }   
    }
    int solve(int id, int o, int N) {
        //memset(tp, 0, sizeof(tp));
        if(!ali[id]) return 0;
        if(ali[id] == 1) {
            for(int i = 1; i <= N; i++) tp[i - 1] = g[N][i];
        } else {
            int down = inv(1 - ali[id]);
            tp[0] = mul(g[N][0], down);
            for(int i = 1; i <= N; i++) 
                tp[i] = mul(g[N][i] - mul(tp[i - 1], ali[id]), down);
        }
    
        int ans = 0;
        for(int i = 1; i <= N; i++) 
            ans = add(ans, mul(mul(ali[id], tp[i - 1]), Inv[i]));
        return ans;
    }
    signed main() {
        //freopen("faceless10.in", "r", stdin);
    //  freopen("b.out", "w", stdout);
        N = read();
        for(int i = 1; i <= N; i++) a[i] = read(), t[0][i][a[i]] = 1, f[0][i] = a[i], Inv[i] = inv(i);
        Q = read();
        int o = 1;
        for(int i = 1; i <= Q; i++, o ^= 1) {
            int opt = read();
            memcpy(t[o], t[o ^ 1], sizeof(t[o]));
            if(opt == 0) {//
                int id = read(), u = read(), v = read(), p = 1ll * u * inv(v) % mod;
                t[o][id][0] = add(t[o][id][0], mul(p, t[o][id][1]));
                for(int j = 1; j <= a[id]; j++) t[o][id][j] = add(mul(p, t[o ^ 1][id][j + 1]), mul(1 - p, t[o ^ 1][id][j]));
            } else if(opt == 1) {
                int k = read();
                for(int i = 1; i <= k; i++) em[i] = read();
                Pre(o, k);
                for(int i = k; i >= 1; i--) ans[i] = solve(i, o, k);
                for(int i = 1; i <= k; i++) printf("%d ", ans[i]); puts("");
            }
        }
        for(int i = 1; i <= N; i++) {
            int ans = 0;
            for(int j = 1; j <= a[i]; j++) 
                ans = add(ans, mul(j, t[o ^ 1][i][j]));
            printf("%d ", ans);
            
        }
        return 0;
    }
    /*
    */
    
    
  • 相关阅读:
    最优二叉查找树
    最长公共子序列问题
    最大子段和问题
    01背包问题
    浅析LRU(K-V)缓存
    LeetCode——LRU Cache
    LeetCode——Gas Station
    LeetCode——Jump Game II
    LeetCode——Jump Game
    LeetCode——Implement Trie (Prefix Tree)
  • 原文地址:https://www.cnblogs.com/zwfymqz/p/9834570.html
Copyright © 2011-2022 走看看