zoukankan      html  css  js  c++  java
  • [PKUWC2018]Minimax 题解

    根据题意,若一个点有子节点,则给出权值;否则可以从子节点转移得来。

    若没有子节点,则直接给出权值;

    若只有一个子节点,则概率情况与该子节点完全相同;

    若有两个子节点,则需要从两个子节点中进行转移。

    如何转移?显然,若权值 $i$ 在左子树,要取到它,需要在 $p_i$ 的概率中左子树较大,在 $(1-p_i)$ 的概率中左子树较小,右子树同理。因为当权值 $i$ 在左子树时右子树取到它的概率为 $0$ ,因此可以直接将两个子树的转移式相加合并,没有影响。即:

    设节点 $x$ 取到权值 $i$ 的概率为 $f_{x,i}$ ,节点数为 $n$ ,则有:

    $$f_{x,i}=f_{lson(x),i}*(p_x*sum_{j=1}^{i-1} f_{rson(x),j}+(1-p_x)*sum_{j=i+1}^{n} f_{rson(x),j})+f_{rson(x),i}*(p_x*sum_{j=1}^{i-1} f_{lson(x),j}+(1-p_x)*sum_{j=i+1}^{n} f_{lson(x),j})$$

    如何求值?通过观察可以发现,这个式子同时需要左儿子的值、右儿子的值以及其前缀、后缀和,可以想到用线段树合并进行求值。

    如何实现?在线段树合并的同时维护前、后缀和,打上乘法标记即可。记得离散化。

    本题需要用到分数取模:$frac{a}{b}mod k=frac{a}{b}*b^{k-1}mod k=a*b^{k-2}mod k$ 。

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    using namespace std;
    const int N=3e5+100;
    const int P=998244353;
    int qpow(int x,int y)
    {
        int ans=1;
        while(y)
        {
            if(y&1)
                ans=1LL*ans*x%P;
            x=1LL*x*x%P,y>>=1;
        }
        return ans;
    }//快速幂
    const int D=qpow(10000,P-2);
    struct Seg
    {
        int lson,rson,sum,tag;
        #define lson(i) t[i].lson
        #define rson(i) t[i].rson
        #define sum(i) t[i].sum
        #define tag(i) t[i].tag//乘法标记
    }t[N<<5];
    int n,cnt,h,ans;
    int v[N],r[N],a[N],cnts[N],anss[N],s[N][2];
    int newp()
    {
        int p=++cnt;
        tag(p)=1;
        return p;
    }//新建节点
    void update(int p,int val)
    {
        sum(p)=1LL*sum(p)*val%P;
        tag(p)=1LL*tag(p)*val%P;
    }//更新
    void pushdown(int p)
    {
        if(tag(p)>1)
            update(lson(p),tag(p)),update(rson(p),tag(p)),tag(p)=1;
    }//下传标记
    void change(int &p,int l,int r,int k,int val)
    {
        if(!p)
            p=newp();
        if(l==r)
        {
            sum(p)=val;
            return ;
        }
        pushdown(p);
        int mid=l+r>>1;
        if(k<=mid)
            change(lson(p),l,mid,k,val);
        else
            change(rson(p),mid+1,r,k,val);
        sum(p)=sum(lson(p))+sum(rson(p))%P;
    }//修改
    void ask(int p,int l,int r)
    {
        if(!p)
            return ;
        if(l==r)
        {
            anss[l]=sum(p);
            return ;
        }
        pushdown(p);
        int mid=l+r>>1;
        ask(lson(p),l,mid),ask(rson(p),mid+1,r);
    }//输出答案到对应数组
    int merge(int x,int y,int l,int r,int xtag,int ytag,int val)
    {
        if(!x && !y)
            return 0;
        if(!x)
        {
            update(y,ytag);
            return y;
        }
        if(!y)
        {
            update(x,xtag);
            return x;
        }
        pushdown(x),pushdown(y);//下传标记
        int mid=l+r>>1,lx=sum(lson(x)),ly=sum(lson(y)),rx=sum(rson(x)),ry=sum(rson(y));//先存值,否则之后会被改动
        lson(x)=merge(lson(x),lson(y),l,mid,(xtag+1LL*ry*(1-val+P)%P)%P,(ytag+1LL*rx*(1-val+P)%P)%P,val);
        rson(x)=merge(rson(x),rson(y),mid+1,r,(xtag+1LL*ly*val%P)%P,(ytag+1LL*lx*val%P)%P,val);
        sum(x)=(sum(lson(x))+sum(rson(x)))%P;
        return x;
    }
    void pre()
    {
        sort(a+1,a+h+1);
        for(int i=1;i<=n;i++)
            if(!cnts[i])
                v[i]=lower_bound(a+1,a+h+1,v[i])-a;//权值离散化
            else
                v[i]=1LL*v[i]*D%P;//存概率,分数取模
    }//预处理权值
    void solve(int x)
    {
        if(!cnts[x])
        {
            change(r[x],1,h,v[x],1);
            return ;
        }//没有子节点,插入权值
        if(cnts[x]==1)
        {
            solve(s[x][0]);
            r[x]=r[s[x][0]];
            return ;
        }//只有一个子节点
        solve(s[x][0]),solve(s[x][1]);
        r[x]=merge(r[s[x][0]],r[s[x][1]],1,h,0,0,v[x]);//有两个子节点
    }
    int main()
    {
        scanf("%d",&n);
        for(int i=1,x;i<=n;i++)
        {
            scanf("%d",&x);
            if(x)
                s[x][cnts[x]++]=i;
        }
        for(int i=1,x;i<=n;i++)
        {
            scanf("%d",&v[i]);
            if(!cnts[i])
                a[++h]=v[i];
        }
        pre(),solve(1),ask(r[1],1,h);
        for(int i=1;i<=h;i++)
            ans=(ans+1LL*i*a[i]%P*anss[i]%P*anss[i])%P;//计算答案
        printf("%d",ans);
        return 0;
    }
  • 相关阅读:
    postgresql postgres.Jsonb
    gorm jsonb
    json
    postgresql重置序列和自增主键
    go build x509: certificate has expired or is not yet valid
    权限修饰符
    交换两个变量的值
    Java编译报错:编码GBK的不可映射字符
    原码反码补码的理解
    Scanner类中hasNext()方法的解析
  • 原文地址:https://www.cnblogs.com/TEoS/p/13289222.html
Copyright © 2011-2022 走看看