zoukankan      html  css  js  c++  java
  • 【BZOJ3451】Tyvj1953 Normal

    题目来源:NOI2019模拟测试赛(七)

    非原题面,题意有略微区别

    题意:

    吐槽:

    心态崩了。

    好不容易场上想出一题正解,写了三个小时结果写了个假的点分治,卡成$O(n^2)$

    我退役吧。

    题解:

    原题是求随机树分治的期望深度和,题意相同。

    对于一个点$x$,考虑点$y$是否能作为它在点分树上的祖先节点,显然当且仅当$y$在$x$到$y$的路径中第一个被选为分治中心时会对$x$产生1的贡献;

    由于路径上所有点被选到的概率都是相等的,所以此时的期望就是$frac{1}{dis(x,y)}$;

    那么总的期望就是$sumlimits_{x=1}^{n}sumlimits_{y=1}^{n}frac{1}{dis(x,y)}$;

    在这里写个暴力即可爆踩我的假点分治;

    考虑统计每种长度的路径条数,可以用点分治做,并且在点分树里合并时子树的期望是一个卷积的形式,因此可以用FFT来加速;

    于是我就快乐的写了个点分治+FFT,获得了60分的好成绩;

    为什么?参考这篇博客的证明,我最初的写法就是其中的第一种写法,搜完一个子树就和已经搜过的合并,这样做的话FFT的长度会是$子树中最大深度 imes 根节点儿子个数=O(n^2)$的,正确的写法应该搜完再一起合并,或者像里面说的第二种方法一样直接搜当前子树,更新答案然后搜重心的每个儿子的子树,减去不合法的路径,这样子FFT的长度才是$O(n)$的。

    代码:

    假点分治(60pts):

      1 #include<algorithm>
      2 #include<iostream>
      3 #include<cstring>
      4 #include<cstdio>
      5 #include<cmath>
      6 #include<queue>
      7 #define inf 2147483647
      8 #define eps 1e-9
      9 #define mod 1000000007
     10 using namespace std;
     11 typedef long long ll;
     12 typedef double db;
     13 const db pi=acos(-1.0);
     14 
     15 struct edge{
     16     int v,next;
     17 }a[200001];
     18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],s[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001];
     19 bool used[100001];
     20 struct cp{
     21     db a,b;
     22     cp(){}
     23     cp(db _a,db _b){
     24         a=_a,b=_b;
     25     }
     26     friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);}
     27     friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);}
     28     friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
     29     friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);}
     30     friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);}
     31 }A[200001],B[200001],W[200001][2];
     32 void _(){
     33     for(int i=1;i<=(1<<17);i<<=1){
     34         W[i][0]=cp(cos(pi/i),sin(pi/i));
     35         W[i][1]=cp(cos(pi/i),-sin(pi/i));
     36     }
     37 }
     38 void fft(cp *s,int op){
     39     for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
     40     for(int i=1;i<bit;i<<=1){
     41         //cp w(cos(pi/i),op*sin(pi/i));
     42         cp w=W[i][op==-1];
     43         for(int p=i<<1,j=0;j<bit;j+=p){
     44             cp wk(1,0);
     45             for(int k=j;k<i+j;k++,wk=wk*w){
     46                 cp x=s[k],y=wk*s[k+i];
     47                 s[k]=x+y;
     48                 s[k+i]=x-y;
     49             }
     50         }
     51     }
     52     if(op==-1){
     53         for(int i=0;i<bit;i++){
     54             s[i]=s[i]/(db)bit;
     55         }
     56     }
     57 }
     58 void add(int u,int v){
     59     a[++tot].v=v;
     60     a[tot].next=head[u];
     61     head[u]=tot;
     62 }
     63 void mul(int *ret,int *a,int *b,int n){
     64     for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++;
     65     for(int i=1;i<=bit;i++){
     66         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1));
     67     }
     68     for(int i=0;i<bit;i++){
     69         A[i]=cp((db)a[i],0);
     70         B[i]=cp(0,0);
     71     }
     72     for(int i=1;i<=cnt;i++){
     73         a[b[i]]++;
     74         B[b[i]].a+=1;
     75     }
     76     fft(A,1);
     77     fft(B,1);
     78     for(int i=0;i<bit;i++)A[i]=A[i]*B[i];
     79     fft(A,-1);
     80     for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5);
     81 }
     82 void getrt(int u,int fa){
     83     mx[u]=0;
     84     siz[u]=1;
     85     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
     86         int v=a[tmp].v;
     87         if(!used[v]&&v!=fa){
     88             getrt(v,u);
     89             siz[u]+=siz[v];
     90             mx[u]=max(mx[u],siz[v]);
     91         }
     92     }
     93     mx[u]=max(mx[u],S-mx[u]);
     94     if(mx[u]<mx[rt])rt=u;
     95 }
     96 void getdep(int u,int fa,int dpt){
     97     mxd=max(mxd,dpt);
     98     s[++cnt]=dpt;
     99     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
    100         int v=a[tmp].v;
    101         if(!used[v]&&v!=fa){
    102             getdep(v,u,dpt+1);
    103         }
    104     }
    105 }
    106 void divide(int u){
    107     used[u]=true;
    108     mxd=0;
    109     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
    110         int v=a[tmp].v;
    111         if(!used[v]){
    112             cnt=0;
    113             getdep(v,u,1);
    114             mul(tp,num,s,mxd);
    115             for(int i=0;i<bit;i++)anss[i]+=tp[i];
    116         }
    117     }
    118     for(int i=1;i<=mxd;i++){
    119         anss[i]+=num[i];
    120         num[i]=0;
    121     }
    122     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
    123         int v=a[tmp].v;
    124         if(!used[v]){
    125             S=siz[v];
    126             rt=0;
    127             getrt(v,0);
    128             divide(rt);
    129         }
    130     }
    131 }
    132 int main(){
    133     memset(head,-1,sizeof(head));
    134     _();
    135     scanf("%d",&n);
    136     jc[0]=inv[0]=inv[1]=1;
    137     for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
    138     for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod;
    139     for(int i=1;i<n;i++){
    140         scanf("%d%d",&u,&v);
    141         add(u,v);
    142         add(v,u);
    143     }
    144     S=n;
    145     mx[rt=0]=6666666;
    146     getrt(1,-1);
    147     divide(rt);
    148     ans=n;
    149     for(int i=1;i<=n;i++){
    150         ans=(ans+(ll)anss[i]*inv[i+1]*2%mod)%mod;
    151     }
    152     printf("%lld",(ll)ans*jc[n]%mod);
    153     return 0;
    154 }

    AC代码(100pts):

      1 #include<algorithm>
      2 #include<iostream>
      3 #include<cstring>
      4 #include<cstdio>
      5 #include<cmath>
      6 #include<queue>
      7 #define inf 2147483647
      8 #define eps 1e-9
      9 #define mod 1000000007
     10 using namespace std;
     11 typedef long long ll;
     12 typedef double db;
     13 const db pi=acos(-1.0);
     14 
     15 struct edge{
     16     int v,next;
     17 }a[200001];
     18 int n,u,v,S,rt,mxd,bit,bitnum,tot=0,cnt=0,ans=0,jc[100001],inv[100001],anss[200001],tp[200001],num[200001],rev[200001],head[100001],mx[100001],siz[100001],dep[100001],dps[100001];
     19 bool used[100001];
     20 struct cp{
     21     db a,b;
     22     cp(){}
     23     cp(db _a,db _b){
     24         a=_a,b=_b;
     25     }
     26     friend cp operator +(cp a,cp b){return cp(a.a+b.a,a.b+b.b);}
     27     friend cp operator -(cp a,cp b){return cp(a.a-b.a,a.b-b.b);}
     28     friend cp operator *(cp a,cp b){return cp(a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a);}
     29     friend cp operator *(cp a,db b){return cp(a.a*b,a.b*b);}
     30     friend cp operator /(cp a,db b){return cp(a.a/b,a.b/b);}
     31 }A[200001],B[200001],W[200001][2];
     32 void _(){
     33     for(int i=1;i<=(1<<17);i<<=1){
     34         W[i][0]=cp(cos(pi/i),sin(pi/i));
     35         W[i][1]=cp(cos(pi/i),-sin(pi/i));
     36     }
     37 }
     38 void fft(cp *s,int op){
     39     for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
     40     for(int i=1;i<bit;i<<=1){
     41         //cp w(cos(pi/i),op*sin(pi/i));
     42         cp w=W[i][op==-1];
     43         for(int p=i<<1,j=0;j<bit;j+=p){
     44             cp wk(1,0);
     45             for(int k=j;k<i+j;k++,wk=wk*w){
     46                 cp x=s[k],y=wk*s[k+i];
     47                 s[k]=x+y;
     48                 s[k+i]=x-y;
     49             }
     50         }
     51     }
     52     if(op==-1){
     53         for(int i=0;i<bit;i++){
     54             s[i]=s[i]/(db)bit;
     55         }
     56     }
     57 }
     58 void add(int u,int v){
     59     a[++tot].v=v;
     60     a[tot].next=head[u];
     61     head[u]=tot;
     62 }
     63 void mul(int *ret,int *a,int *b,int n){
     64     for(bit=1,bitnum=0;bit<=n*2;bit<<=1)bitnum++;
     65     for(int i=1;i<bit;i++){
     66         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bitnum-1));
     67     }
     68     for(int i=0;i<bit;i++){
     69         A[i]=cp((db)a[i],0);
     70         B[i]=cp((db)b[i],0);
     71     }
     72     fft(A,1);
     73     fft(B,1);
     74     for(int i=0;i<bit;i++)A[i]=A[i]*B[i];
     75     fft(A,-1);
     76     for(int i=0;i<bit;i++)ret[i]=(int)(A[i].a+0.5);
     77 }
     78 void getrt(int u,int fa){
     79     mx[u]=0;
     80     siz[u]=1;
     81     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
     82         int v=a[tmp].v;
     83         if(!used[v]&&v!=fa){
     84             getrt(v,u);
     85             siz[u]+=siz[v];
     86             mx[u]=max(mx[u],siz[v]);
     87         }
     88     }
     89     mx[u]=max(mx[u],S-mx[u]);
     90     if(mx[u]<mx[rt])rt=u;
     91 }
     92 void getdep(int u,int fa,int dpt){
     93     mxd=max(mxd,dpt);
     94     dps[dpt]++;
     95     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
     96         int v=a[tmp].v;
     97         if(!used[v]&&v!=fa){
     98             getdep(v,u,dpt+1);
     99         }
    100     }
    101 }
    102 void divide(int u){
    103     used[u]=true;
    104     num[0]=1;
    105     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
    106         int v=a[tmp].v;
    107         if(!used[v]){
    108             getdep(v,u,1);
    109             for(int i=1;i<=mxd;i++){
    110                 num[i]+=dps[i];
    111                 tp[i]=dps[i];
    112                 dps[i]=0;
    113             }
    114             cnt=max(cnt,mxd);
    115             mul(tp,tp,tp,mxd);
    116             for(int i=1;i<=mxd*2;i++){
    117                 anss[i]-=tp[i];
    118                 tp[i]=0;
    119             }
    120             mxd=0;
    121         }
    122     }
    123     for(int i=0;i<=cnt;i++){
    124         tp[i]=num[i];
    125         num[i]=0;
    126     }
    127     mul(tp,tp,tp,cnt);
    128     for(int i=0;i<=cnt*2;i++){
    129         anss[i]+=tp[i];
    130         tp[i]=0;
    131     }
    132     cnt=0;
    133     for(int tmp=head[u];tmp!=-1;tmp=a[tmp].next){
    134         int v=a[tmp].v;
    135         if(!used[v]){
    136             S=siz[v];
    137             rt=0;
    138             getrt(v,0);
    139             divide(rt);
    140         }
    141     }
    142 }
    143 int main(){
    144     memset(head,-1,sizeof(head));
    145     _();
    146     scanf("%d",&n);
    147     jc[0]=inv[0]=inv[1]=1;
    148     for(int i=2;i<=n+1;i++)inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
    149     for(int i=1;i<=n+1;i++)jc[i]=(ll)jc[i-1]*i%mod;
    150     for(int i=1;i<n;i++){
    151         scanf("%d%d",&u,&v);
    152         add(u,v);
    153         add(v,u);
    154     }
    155     S=n;
    156     mx[rt=0]=6666666;
    157     getrt(1,-1);
    158     divide(rt);
    159     ans=n;
    160     for(int i=1;i<=n;i++){
    161         ans=(ans+(ll)anss[i]*inv[i+1]%mod)%mod;
    162     }
    163     printf("%lld",(ll)ans*jc[n]%mod);
    164     return 0;
    165 }
  • 相关阅读:

    IT人的素质 & 设计杂谈
    结构化思维思维的结构
    [WM].NET CF下如何提高应用程序的性能 【转载】
    无题
    [WM]谁抢走了应用程序的性能? 【转载】
    繁体编码文本文件转换为简体编码的工具
    生成VB多行字符串常量的工具
    跟我一步一步开发自己的Openfire插件
    cnblogs博文浏览[推荐、Top、评论、关注、收藏]利器代码片段
  • 原文地址:https://www.cnblogs.com/dcdcbigbig/p/10140006.html
Copyright © 2011-2022 走看看