zoukankan      html  css  js  c++  java
  • 【BZOJ3451】Normal-概率期望+点分治+NTT

    测试地址:Normal
    题目大意:将点分治中找分治重心的过程,变成随机在当前块中取一个点,点分治的每一步骤(即处理一块)消耗的时间为块的大小,问总消耗时间的期望。
    做法:本题需要用到概率期望+点分治+NTT。
    首先根据期望的线性性,不难想到分开计算每个点被计算的期望次数,累加起来就是答案。而每个点被计算的次数,等于它在点分树上的深度(根深度为1),那么对于一个点x,某点y(可以是点x自己,它自己一定为自己的祖先)作为点分树上它的祖先的概率,等同于在原树中,点y是在路径xy上的点中第一个被选为分治重心的概率,它们是相互独立的,把这些概率累加起来就是点x的期望深度。具体地,因为每个点被第一次选的概率相同,所以点y作为点x祖先的概率为1dis(x,y),其中dis(x,y)xy路径上点的数目。
    因此答案就是求i=1nj=1n1dis(i,j),暴力计算是O(n2)的,为了加快这个速度,容易想到计算dis为不同数值时的路径数目,这是一个经典的点分治问题,而在具体计算时,有两种可行的写法:
    第一种做法,是在处理某一个分治重心时,将所有分出的子树按大小从小到大排序,然后顺次用FFT/NTT合并信息,显然这样是O(nlog2n)的。
    第二种做法,是在处理某一个分治重心时,先直接用一次FFT/NTT算出该块中过分治重心的路径(可能自交)的信息,然后枚举每棵子树去重,显然这样也是O(nlog2n)的。
    两种做法都可行,而第二种做法写起来更简单,所以这里我用了第二种做法,于是我们就完成了这一题。至于为什么可以用NTT,因为300002<998244353,所以取模后和原值是相同的,NTT写起来又特别方便,还不用担心精度误差,美滋滋。
    我傻逼的地方:TLE,以为是常数写挂,结果是分治重心求错了……简直是太菜了……
    以下是本人代码:

    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const ll mod=998244353;
    const ll g=3;
    int n,first[30010]={0},tot=0,top,q[30010],r[120010];
    int siz[30010],mxson[30010];
    ll now[120010]={0},final[120010]={0};
    bool vis[30010]={0};
    struct edge
    {
        int v,next;
    }e[60010];
    
    void insert(int a,int b)
    {
        e[++tot].v=b;
        e[tot].next=first[a];
        first[a]=tot;
    }
    
    void dp(int v,int fa)
    {
        siz[v]=1;
        mxson[v]=0;
        q[++top]=v;
        for(int i=first[v];i;i=e[i].next)
            if (e[i].v!=fa&&!vis[e[i].v])
            {
                dp(e[i].v,v);
                mxson[v]=max(mxson[v],siz[e[i].v]);
                siz[v]+=siz[e[i].v];
            }
    }
    
    int find(int v)
    {
        top=0;
        dp(v,-1);
        int mn=1000000000,mni;
        for(int i=1;i<=top;i++)
            if (max(mxson[q[i]],siz[v]-siz[q[i]])<mn)
            {
                mn=max(mxson[q[i]],siz[v]-siz[q[i]]);
                mni=q[i];
            }
        return mni;
    }
    
    ll power(ll a,ll b)
    {
        ll s=1,ss=a;
        if (b<0) b+=mod-1;
        while(b)
        {
            if (b&1) s=s*ss%mod;
            ss=ss*ss%mod;b>>=1;
        }
        return s;
    }
    
    void NTT(ll *a,int type,int n)
    {
        for(int i=0;i<=n;i++)
            if (i<r[i]) swap(a[i],a[r[i]]);
        for(int mid=1;mid<n;mid<<=1)
        {
            ll W=power(g,type*(mod-1)/(mid<<1));
            for(int l=0;l<n;l+=(mid<<1))
            {
                ll w=1;
                for(int k=0;k<mid;k++,w=w*W%mod)
                {
                    ll x=a[l+k],y=w*a[l+mid+k]%mod;
                    a[l+k]=(x+y)%mod;
                    a[l+mid+k]=(x-y+mod)%mod;
                }
            }
        }
        if (type==-1)
        {
            ll inv=power(n,mod-2);
            for(int i=0;i<=n;i++)
                a[i]=a[i]*inv%mod;
        }
    }
    
    void calc(int v,int fa,int dis)
    {
        now[dis]++;
        for(int i=first[v];i;i=e[i].next)
            if (e[i].v!=fa&&!vis[e[i].v])
                calc(e[i].v,v,dis+1);
    }
    
    void calctot(int v,int d,int siz,ll type)
    {
        int x=1,bit=0;
        while(x<=(siz<<1)) x<<=1,bit++;
        r[0]=0;
        for(int i=1;i<=x;i++)
            r[i]=(r[i>>1]>>1)|((i&1)<<(bit-1));
    
        for(int i=0;i<=x;i++) now[i]=0;
        calc(v,-1,d);
        NTT(now,1,x);
        for(int i=0;i<=x;i++) now[i]=now[i]*now[i]%mod;
        NTT(now,-1,x);
        for(int i=0;i<=x;i++) final[i]+=type*now[i];
    }
    
    int solve(int v)
    {
        int totsiz=1;
        v=find(v);
        vis[v]=1;
    
        calctot(v,0,siz[v],1);
        for(int i=first[v];i;i=e[i].next)
            if (!vis[e[i].v])
            {
                int newsiz=solve(e[i].v);
                calctot(e[i].v,1,newsiz,-1);
                totsiz+=newsiz;
            }
    
        vis[v]=0;
        return totsiz;
    }
    
    int main()
    {
        scanf("%d",&n);
        for(int i=1;i<n;i++)
        {
            int a,b;
            scanf("%d%d",&a,&b);
            insert(a,b),insert(b,a);
        }
    
        solve(0);
        double ans=0.0;
        for(int i=0;i<=n;i++)
            ans+=(double)final[i]/(double)(i+1);
        printf("%.4lf",ans);
    
        return 0;
    }
  • 相关阅读:
    CCF认证201809-2买菜
    git删除本地保存的账号和密码
    mysql表分区
    使用java代码将时间戳和时间互相转换
    Mysql数据库表被锁定处理
    mysql查询某个数据库表的数量
    编译nginx错误:make[1]: *** [/pcre//Makefile] Error 127
    LINUX下安装pcre出现WARNING: 'aclocal-1.15' is missing on your system错误的解决办法
    linux下安装perl
    [剑指Offer]26-树的子结构
  • 原文地址:https://www.cnblogs.com/Maxwei-wzj/p/9793344.html
Copyright © 2011-2022 走看看