zoukankan      html  css  js  c++  java
  • BZOJ 4318 OSU! ( 期望DP )

    题目链接

    题意 : OSU 是一款群众喜闻乐见的休闲软件。 我们可以把 OSU 的规则简化与改编成以下的样子 : 一共有 n 次操作,每次操作只有成功与失败之分,成功对应 1 ,失败对应 0 ,n次操作对应为 1 个长度为 n 的 01 串。在这个串中连续的 X 个 1 可以贡献 X^3 的分数,这 X 个 1 不能被其他连续的 1 所包含(也就是极长的一串 1 ,具体见样例解释) 现在给出n,以及每个操作的成功率,请你输出期望分数,输出四舍五入后保留 1 位小数。 【样例说明】 000分数为0,001分数为1,010分数为1,100分数为1,101分数为2,110分数为8,011分数为8,111分数为27,总和为48,期望为48/8=6.0  ( N<=100000 )

     

    分析 :

    考虑 期望DP

    首先得分是由连续 1 的长度决定的

    而且每段连续 1 的贡献是相互独立的

    那么考虑这样一个 len[i] = 以 i 位置为极长 1 的结尾的后缀期望长度

    假设当前考虑到位置 i + 1

    p[i] 为到 i 位置为 1 的概率

    dp[i] 为以 i 位置为极长 1 的结尾的后缀期望得分

    那么其得分期望 ( 即贡献 ) 有如下计算过程

    dp[i + 1] = ( len[i] + 1 )^3 * p[i+1]

    ans += dp[i+1] - dp[i] + 0 * ( 1 - p[i+1] )

           += dp[i+1] - dp[i] 

    后面的 0 * (1 - p[i+1]) 为 i + 1 这个位为 0 的贡献 

    前面的 dp[i+1] - dp[i] 为 i + 1 这个位为 1 的贡献

    为什么是加 dp[i+1] - dp[i] 而不是 dp[i+1] 呢?

    因为 dp[i+1] 本身就是由 dp[i] 递推而来

    即 dp[i+1] = dp[i] + X

    所以相当于 dp 数组是一个前缀期望得分和的形式

    如果要得到每一位的贡献、当然是 dp[i+1] - dp[i]

    每一位加起来就是总贡献

    那么你可能会想能不能直接递推 dp[i+1] 就行了何必那么麻烦

    实际上根据期望的计算公式

    你可以得到 dp[i+1] - dp[i] 这个差值的通式 X

    那么就可以递推 dp[i+1] = dp[i] + X 

    最后答案就是 dp[n]

    下面来讲一下其中期望长度怎么递推、即 len[i]

    上面说到要计算 dp[i + 1] = ( len[i] + 1 )^3 * p[i+1]

    网上很多题解说  E(x^2) != E(x)^2 或 E(x^3) != E(x)^3

    指的就是在 ( len[i] + 1 )^3 的计算这里

    将这条公式展开有

    len[i]^3 + 3*len[i]^2 + 3*len[i] + 1

    那么你不能只算出 len[i] 

    然后计算 dp[i+1] = ( len[i]^3 + 3*len[i]^2 + 3*len[i] + 1 ) * p[i+1]

    你需要另外递推 len[i]^3  和 len[i]^2 的期望

    即先算出 len[i] 记为 a

    再用公式求 len[i]^2 记为 b

    再用公式求 len[i]^3 记为 c

    则 dp[i+1] = (c + 3*b + 3*a + 1)*p[i+1]

    #include<bits/stdc++.h>
    #define LL long long
    #define ULL unsigned long long
    
    #define scl(i) scanf("%lld", &i)
    #define scll(i, j) scanf("%lld %lld", &i, &j)
    #define sclll(i, j, k) scanf("%lld %lld %lld", &i, &j, &k)
    #define scllll(i, j, k, l) scanf("%lld %lld %lld %lld", &i, &j, &k, &l)
    
    #define scs(i) scanf("%s", i)
    #define sci(i) scanf("%d", &i)
    #define scd(i) scanf("%lf", &i)
    #define scIl(i) scanf("%I64d", &i)
    #define scii(i, j) scanf("%d %d", &i, &j)
    #define scdd(i, j) scanf("%lf %lf", &i, &j)
    #define scIll(i, j) scanf("%I64d %I64d", &i, &j)
    #define sciii(i, j, k) scanf("%d %d %d", &i, &j, &k)
    #define scddd(i, j, k) scanf("%lf %lf %lf", &i, &j, &k)
    #define scIlll(i, j, k) scanf("%I64d %I64d %I64d", &i, &j, &k)
    #define sciiii(i, j, k, l) scanf("%d %d %d %d", &i, &j, &k, &l)
    #define scdddd(i, j, k, l) scanf("%lf %lf %lf %lf", &i, &j, &k, &l)
    #define scIllll(i, j, k, l) scanf("%I64d %I64d %I64d %I64d", &i, &j, &k, &l)
    
    #define lson l, m, rt<<1
    #define rson m+1, r, rt<<1|1
    #define lowbit(i) (i & (-i))
    #define mem(i, j) memset(i, j, sizeof(i))
    
    #define fir first
    #define sec second
    #define VI vector<int>
    #define ins(i) insert(i)
    #define pb(i) push_back(i)
    #define pii pair<int, int>
    #define VL vector<long long>
    #define mk(i, j) make_pair(i, j)
    #define all(i) i.begin(), i.end()
    #define pll pair<long long, long long>
    
    #define _TIME 0
    #define _INPUT 0
    #define _OUTPUT 0
    clock_t START, END;
    void __stTIME();
    void __enTIME();
    void __IOPUT();
    using namespace std;
    
    int main(void){__stTIME();__IOPUT();
    
    
        int n;
    
        sci(n);
    
        double len1 = 0, len2 = 0, len3 = 0, dp = 0, ans = 0;
    
        for(int i=1; i<=n; i++){
    
            double p; scd(p);
    
            dp = len3 * p;
    
            ans += (len3 + 3*len2 + 3*len1 + 1) * p - dp;
    
            len3 = (len3 + 3*len2 + 3*len1 + 1) * p;
            len2 = (len2 + 2*len1 + 1) * p;
            len1 = (len1 + 1) * p;
        }
    
        printf("%.1f
    ", len3);
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    __enTIME();return 0;}
    
    
    void __stTIME()
    {
        #if _TIME
            START = clock();
        #endif
    }
    
    void __enTIME()
    {
        #if _TIME
            END = clock();
            cerr<<"execute time = "<<(double)(END-START)/CLOCKS_PER_SEC<<endl;
        #endif
    }
    
    void __IOPUT()
    {
        #if _INPUT
            freopen("in.txt", "r", stdin);
        #endif
        #if _OUTPUT
            freopen("out.txt", "w", stdout);
        #endif
    }
    View Code

    然后来讲一下网上大部分题解的写法递推 dp[i+1] = dp[i] + X

    这个 X 怎么求

    期望的得分 E( (len[i]+1)^3 ) = E( len[i]^3 ) + X

    那么就可以根据这个求出 X = E( (len[i]+1)^3 ) - E( len[i]^3 )

    X = E( (len[i]+1)^3 -  len[i]^3 )

       = E( 3*len[i]^2 + 3*len[i] + 1 )

       = ( 3*len[i]^2 + 3*len[i] + 1 ) * p[i+1]

    故得到递推式子 dp[i+1] = dp[i] + ( 3*len[i]^2 + 3*len[i] + 1 ) * p[i+1]

    在 len[i] 的计算上注意 上面说的 E(x^2) != E(x)^2

    #include<bits/stdc++.h>
    #define LL long long
    #define ULL unsigned long long
    
    #define scl(i) scanf("%lld", &i)
    #define scll(i, j) scanf("%lld %lld", &i, &j)
    #define sclll(i, j, k) scanf("%lld %lld %lld", &i, &j, &k)
    #define scllll(i, j, k, l) scanf("%lld %lld %lld %lld", &i, &j, &k, &l)
    
    #define scs(i) scanf("%s", i)
    #define sci(i) scanf("%d", &i)
    #define scd(i) scanf("%lf", &i)
    #define scIl(i) scanf("%I64d", &i)
    #define scii(i, j) scanf("%d %d", &i, &j)
    #define scdd(i, j) scanf("%lf %lf", &i, &j)
    #define scIll(i, j) scanf("%I64d %I64d", &i, &j)
    #define sciii(i, j, k) scanf("%d %d %d", &i, &j, &k)
    #define scddd(i, j, k) scanf("%lf %lf %lf", &i, &j, &k)
    #define scIlll(i, j, k) scanf("%I64d %I64d %I64d", &i, &j, &k)
    #define sciiii(i, j, k, l) scanf("%d %d %d %d", &i, &j, &k, &l)
    #define scdddd(i, j, k, l) scanf("%lf %lf %lf %lf", &i, &j, &k, &l)
    #define scIllll(i, j, k, l) scanf("%I64d %I64d %I64d %I64d", &i, &j, &k, &l)
    
    #define lson l, m, rt<<1
    #define rson m+1, r, rt<<1|1
    #define lowbit(i) (i & (-i))
    #define mem(i, j) memset(i, j, sizeof(i))
    
    #define fir first
    #define sec second
    #define VI vector<int>
    #define ins(i) insert(i)
    #define pb(i) push_back(i)
    #define pii pair<int, int>
    #define VL vector<long long>
    #define mk(i, j) make_pair(i, j)
    #define all(i) i.begin(), i.end()
    #define pll pair<long long, long long>
    
    #define _TIME 0
    #define _INPUT 0
    #define _OUTPUT 0
    clock_t START, END;
    void __stTIME();
    void __enTIME();
    void __IOPUT();
    using namespace std;
    
    int main(void){__stTIME();__IOPUT();
    
    
        int n;
    
        sci(n);
    
        double len1 = 0, len2 = 0, len3 = 0, dp = 0, ans = 0;
    
        for(int i=1; i<=n; i++){
    
            double p; scd(p);
    
            dp = dp + (3*len2 + 3*len1 + 1) * p;
    
            len3 = (len3 + 3*len2 + 3*len1 + 1) * p;
            len2 = (len2 + 2*len1 + 1) * p;
            len1 = (len1 + 1) * p;
        }
    
        printf("%.1f
    ", dp);
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    __enTIME();return 0;}
    
    
    void __stTIME()
    {
        #if _TIME
            START = clock();
        #endif
    }
    
    void __enTIME()
    {
        #if _TIME
            END = clock();
            cerr<<"execute time = "<<(double)(END-START)/CLOCKS_PER_SEC<<endl;
        #endif
    }
    
    void __IOPUT()
    {
        #if _INPUT
            freopen("in.txt", "r", stdin);
        #endif
        #if _OUTPUT
            freopen("out.txt", "w", stdout);
        #endif
    }
    View Code
  • 相关阅读:
    小问题收集
    JSON.NET与LINQ序列化示例教程
    前台页面中json和字符串相互转化
    jQuery Validate (1)
    jQuery Validate (摘自官网)
    SQL基础(八)-- sql左右连接中的on and 和 on where 的区别
    SQL基础(七)--or和in的使用
    SQL基础(六)--RaiseError的用法
    C#基础(三)--Sort排序
    C#中Equals和==的区别 (面试官经常会问到)
  • 原文地址:https://www.cnblogs.com/qwertiLH/p/9551843.html
Copyright © 2011-2022 走看看