zoukankan      html  css  js  c++  java
  • Luogu P5450 [THUPC2018]淘米神的树

    题意

    写的很明白了,不需要解释。

    ( exttt{Data Range:}1leq nleq 234567)

    题解

    国 际 计 数 水 平

    首先考虑一开始只有一个黑点的情况怎么做。

    我们钦定黑点为根,设 (f_i) 表示以 (i) 为子树,并且 (i) 这个节点为黑色的答案,那么我们得到:

    [f_u=(sz_u-1)!prodlimits_{fa_v=u}frac{f_v}{sz_v!} ]

    接下来我们阐述这个式子到底是在干什么。

    首先我们枚举可能的排列,由于现在只有 (u) 是黑色的,第一轮肯定是先把 (u) 变成红色。也就是说,(u) 只能放在第一个,而剩下的可以任意。(注意这里是可能而不是合法)

    接下来考虑排除掉所有不合法的排列,枚举每一棵子树,考虑某棵子树所有节点之间的偏序关系。

    对于一个树 (T) 对应的可能的序列 (P) 来说,我们按顺序从 (P) 中选出一些元素构造 (Q),使得 (Q) 中的每一个元素都是 (T) 中某个子树 (S) 中存在的节点的编号,并且 (S) 中每个节点的编号都要在 (Q) 中出现。

    这里,(T) 的根为 (u)(S) 的根为 (v)(fa_v=u)

    然后我们发现,如果序列 (Q) 是不合法的,那么序列 (P) 也是不合法的。

    对所有子树都这么做就得到同时满足所有约束的序列个数,也就是上面的式子。

    接下来我们考虑直接计算根节点的答案。

    [f_{rt}=frac{(sz_{rt}-1)!}{prodlimits_{fa_v=rt}sz_v!}prodlimits_{fa_v=rt}f_v ]

    然后我们对于所有根节点的儿子 (v) 代入进来,再将所有孙子代入进来……相当于我们枚举所有节点计算,于是有

    [f_{rt}=prodlimits_{u=1}^{n}frac{(sz_{u}-1)!}{prodlimits_{fa_v=u}sz_v!} ]

    简单化开

    [f_{rt}=frac{prodlimits_{u=1}^{n}(sz_{u}-1)!}{prodlimits_{u=1}^{n}prodlimits_{fa_v=u}sz_v!} ]

    注意分母,也就是说我们对于每个节点 (u) 都枚举 (u) 的所有儿子,相当于枚举所有非根节点(因为每个非根节点总会是某个节点的儿子),于是有

    [f_{rt}=frac{prodlimits_{u=1}^{n}(sz_{u}-1)!}{prodlimits_{u=1,u eq rt}^{n}sz_u!} ]

    现在来拆分子。单独把 (u=rt) 的拿出来,剩下的合并并约分,得到

    [f_{rt}=(sz_{rt}-1)!prodlimits_{u=1,u eq rt}^{n}frac{1}{sz_u} ]

    然后分子分母同乘一个 (sz_{rt}) 得到答案为

    [f_{rt}=sz_{rt}!prodlimits_{u=1}^{n}frac{1}{sz_u} ]

    接下来考虑两个黑点咋做,建立一个虚拟节点 (s),向 (a,b) 连无向边,初始的时候只有 (s) 是黑点。

    第一次只能变 (s),然后就变成 (s) 是红点,(a,b) 是黑点。容易证明这样对答案没有任何影响。

    于是树就变成了一颗基环树,考虑怎么求答案。

    我们枚举环上染红的最后一个点的位置,但是这是不可行的。

    那么我们就用类似于 NOIp2018 Day 2 T1 的方法,考虑枚举一条在环上的边并删除他来计算答案。

    但是这样做可能会用重复的答案,所以我们要考虑如何来排除这些答案。

    注意到对于一种染色方案来说,考虑这个方案在环上最后一个被染成红色的点。我们发现这个点可以有两种方案被染黑,也就是说每个方案都会算 (2) 此,于是最终答案要乘上 (frac{1}{2})

    对于不是环上的点和虚拟节点 (s) 我们都能用 ( exttt{dfs}) 求出子树大小,设他们的和为 (z)

    我们现在考虑这些在环上的点,对于某个在环上的点 (u) 来说,设 (a_u) 表示它的确定子树大小(也就是所有不在环上的儿子的大小之和加一),(b_u) 表示某种断环方式下的 (sz_u),那么通过人类智慧我们发现

    将环上的 (k+1) 个点编号,(s) 点为 (0),其余的依次为 (1sim k),假设断的边是 (r o r+1),那么

    [b_u=egin{cases}sumlimits_{i=1}^{u}a_u,& 1leq uleq r\sumlimits_{i=u}^{k}a_u,&r+1leq uleq kend{cases} ]

    我们对 (a) 做前缀和变成 (c),那么可以发现

    [prod b=prod_{i eq j}vert c_i-c_jvert ]

    其中 (c_0=0)

    然后我们可以用快速插值的套路去优化这个东西,复杂度 (O(nlog^2n))

    代码

    #include<cstdio>
    #include<cstring>
    #include<cctype>
    #include<cmath>
    #include<iostream>
    #include<algorithm>
    #include<vector>
    #define clr(f,n) memset(f,0,sizeof(int)*(n))
    #define cpy(f,g,n) memcpy(f,g,sizeof(int)*(n))
    #pragma GCC optimize("Ofast")
    #pragma GCC optimize("unroll-loops")
    using namespace std;
    typedef int ll;
    typedef long long int li;
    typedef unsigned long long int ull;
    const ll MAXN=524291,MOD=998244353,G=3,INVG=332748118;
    struct Edge{
        ll to,prev;
    };
    Edge ed[MAXN];
    ll n,x,y,z,fct,from,to,tot,tp,cur,res;
    ll rev[MAXN],omgs[MAXN],invo[MAXN];
    ll last[MAXN],sz[MAXN],fa[MAXN],st[MAXN],p[MAXN],c[MAXN];
    ll pr[MAXN],px[MAXN];
    inline ll read()
    {
        register ll num=0,neg=1;
        register char ch=getchar();
        while(!isdigit(ch)&&ch!='-')
        {
            ch=getchar();
        }
        if(ch=='-')
        {
            neg=-1;
            ch=getchar();
        }
        while(isdigit(ch))
        {
            num=(num<<3)+(num<<1)+(ch-'0');
            ch=getchar();
        }
        return num*neg;
    }
    inline void addEdge(ll from,ll to)
    {
        ed[++tot].prev=last[from];
        ed[tot].to=to;
        last[from]=tot;
    }
    inline void dfs(ll node,ll f)
    {
        sz[node]=1,fa[node]=f;
        for(register int i=last[node];i;i=ed[i].prev)
        {
            if(ed[i].to!=f)
            {
                dfs(ed[i].to,node),sz[node]+=sz[ed[i].to];
            }
        }
    }
    inline ll qpow(ll base,ll exponent)
    {
        ll res=1;
        while(exponent)
        {
            if(exponent&1)
            {
                res=(li)res*base%MOD;
            }
            base=(li)base*base%MOD,exponent>>=1;
        }
        return res;
    }
    inline void setupOmg(ll cnt)
    {
        ll limit=log2(cnt)-1,omg;
        omg=qpow(G,(MOD-1)>>(limit+1)),omgs[cnt>>1]=1;
        for(register int i=(cnt>>1|1);i!=cnt;i++)
        {
            omgs[i]=(li)omgs[i-1]*omg%MOD;
        }
        for(register int i=(cnt>>1)-1;i;i--)
        {
            omgs[i]=omgs[i<<1]; 
        }
    }
    inline void NTT(ll *cp,ll cnt,ll inv)
    {
        static ull tcp[MAXN];
        register ll cur=0,x,shift=log2(cnt)-__builtin_ctz(cnt);
        if(inv==-1)
        {
            reverse(cp+1,cp+cnt);
        }
        for(register int i=0;i<cnt;i++)
        {
            tcp[rev[i]>>shift]=cp[i];
        }
        for(register int i=2;i<=cnt;i<<=1)
        {
            cur=i>>1;
            for(register int j=0;j<cnt;j+=i)
            {
                for(register int k=0;k<cur;k++)
                {
                    x=tcp[j|k|cur]*omgs[k|cur]%MOD;
                    tcp[j|k|cur]=tcp[j|k]+MOD-x,tcp[j|k]+=x;
                }
            }
        }
        for(register int i=0;i<cnt;i++)
        {
            cp[i]=tcp[i]%MOD;
        }
        if(inv==1)
        {
            return;
        }
        x=MOD-(MOD-1)/cnt;
        for(register int i=0;i<cnt;i++)
        {
            cp[i]=(li)cp[i]*x%MOD;
        }
    }
    inline void conv(ll fd,ll *f,ll *g,ll *res)
    {
        static ll tmpf[MAXN],tmpg[MAXN];
        ll cnt=1,limit=-1;
        while(cnt<(fd<<1))
        {
            cnt<<=1,limit++;
        }
        for(register int i=0;i<cnt;i++)
        {
            tmpf[i]=i<fd?f[i]:0,tmpg[i]=i<fd?g[i]:0;
            rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
        }
        NTT(tmpf,cnt,1),NTT(tmpg,cnt,1);
        for(register int i=0;i<cnt;i++)
        {
            tmpf[i]=(li)tmpf[i]*tmpg[i]%MOD;
        }
        NTT(tmpf,cnt,-1),cpy(res,tmpf,fd);
    }
    inline void inv(ll fd,ll *f,ll *res)
    {
        static ll tmp[MAXN];
        if(fd==1)
        {
            res[0]=qpow(f[0],MOD-2);
            return;
        }
        inv((fd+1)>>1,f,res);
        ll cnt=1,limit=-1;
        while(cnt<(fd<<1))
        {
            cnt<<=1,limit++;
        }
        for(register int i=0;i<cnt;i++)
        {
            tmp[i]=i<fd?f[i]:0;
            rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
        }
        NTT(tmp,cnt,1),NTT(res,cnt,1);
        for(register int i=0;i<cnt;i++)
        {
            res[i]=(li)(2-(li)tmp[i]*res[i]%MOD+MOD)%MOD*res[i]%MOD;
        }
        NTT(res,cnt,-1),clr(res+fd,cnt-fd-1);
    }
    inline void mod(ll fd,ll gd,ll *f,ll *g,ll *r)
    {
        static ll tmpf[MAXN],tmpg[MAXN],tinv[MAXN],q[MAXN];
        if(fd<gd)
        {
            for(register int i=0;i<gd-1;i++)
            {
                r[i]=f[i];
            }
            return;
        }
        for(register int i=0;i<fd;i++)
        {
            tmpf[i]=f[fd-1-i];
        }
        for(register int i=0;i<gd;i++)
        {
            tmpg[i]=g[gd-1-i];
        }
        inv(fd-gd+2,tmpg,tinv);
        ll cnt=1,limit=-1;
        while(cnt<(fd<<1))
        {
            cnt<<=1,limit++;
        }
        for(register int i=0;i<cnt;i++)
        {
            rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
        }
        NTT(tmpf,cnt,1),NTT(tinv,cnt,1);
        for(register int i=0;i<cnt;i++)
        {
            q[i]=1ll*tmpf[i]*tinv[i]%MOD;
        }
        NTT(q,cnt,-1),reverse(q,q+fd-gd+1);
        for(register int i=0;i<cnt;i++)
        {
            tmpf[i]=tinv[i]=tmpg[i]=0;
            q[i]=i<fd-gd+1?q[i]:0,g[i]=i<gd?g[i]:0;
        }
        cnt>>=1,limit--;
        for(register int i=0;i<cnt;i++)
        {
            rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
        }
        NTT(q,cnt,1),NTT(g,cnt,1);
        for(register int i=0;i<cnt;i++)
        {
            tmpf[i]=1ll*q[i]*g[i]%MOD;
        }
        NTT(g,cnt,-1),NTT(tmpf,cnt,-1);
        for(register int i=0;i<gd-1;i++)
        {
            r[i]=(f[i]-tmpf[i]+MOD)%MOD;
        }
        for(register int i=0;i<cnt;i++)
        {
            q[i]=tmpf[i]=0;
        }
    }
    vector<ll> tmpf2[MAXN<<2];
    void dnc(ll *pts,ll l,ll r,ll node)
    {
        static ll tmp[MAXN],tmp2[MAXN];
        if(l==r)
        {
            tmpf2[node].push_back((MOD-pts[l])%MOD),tmpf2[node].push_back(1);
            return;
        }
        ll mid=(l+r)>>1,ls=node<<1,rs=ls|1;
        dnc(pts,l,mid,ls),dnc(pts,mid+1,r,rs);
        ll d=tmpf2[ls].size(),d2=tmpf2[rs].size();
        copy(tmpf2[ls].begin(),tmpf2[ls].end(),tmp);
        copy(tmpf2[rs].begin(),tmpf2[rs].end(),tmp2);
        ll cnt=1,limit=-1;
        while(cnt<(d+d2))
        {
            cnt<<=1,limit++;
        }
        for(register int i=0;i<cnt;i++)
        {
            rev[i]=(rev[i>>1]>>1)|((i&1)<<limit);
        }
        NTT(tmp,cnt,1),NTT(tmp2,cnt,1);
        for(register int i=0;i<cnt;i++)
        {
            tmp[i]=(li)tmp[i]*tmp2[i]%MOD;
        }
        NTT(tmp,cnt,-1),tmpf2[node].resize(d+d2-1);
        copy(tmp,tmp+d+d2-1,tmpf2[node].begin()),clr(tmp,cnt),clr(tmp2,cnt);
    }
    ll tmpf3[19][MAXN];
    void dnc2(ll fd,ll *f,ll depth,ll l,ll r,ll node,ll *res)
    {
        static ll tmp[MAXN],pw[17];
        if(r-l<=1024)
        {
            for(register int i=l;i<=r;i++)
            {
                ll x=c[i],cur=f[fd-1],v1,v2,v3,v4;
                pw[0]=1;
                for(register int j=1;j<=16;j++)
                {
                    pw[j]=1ll*pw[j-1]*x%MOD;
                }
                for(register int j=fd-2;j-15>=0;j-=16)
                {
                    v1=(1ll*cur*pw[16]+1ll*f[j]*pw[15]+
                        1ll*f[j-1]*pw[14]+1ll*f[j-2]*pw[13])%MOD;
                    v2=(1ll*f[j-3]*pw[12]+1ll*f[j-4]*pw[11]+
                        1ll*f[j-5]*pw[10]+1ll*f[j-6]*pw[9])%MOD;
                    v3=(1ll*f[j-7]*pw[8]+1ll*f[j-8]*pw[7]+
                        1ll*f[j-9]*pw[6]+1ll*f[j-10]*pw[5])%MOD;
                    v4=(1ll*f[j-11]*pw[4]+1ll*f[j-12]*pw[3]+
                        1ll*f[j-13]*pw[2]+1ll*f[j-14]*pw[1])%MOD;
                    cur=(0ll+v1+v2+v3+v4+f[j-15])%MOD;
                }
                for(register int j=((fd-1)&15)-1;~j;j--)
                {
                    cur=(1ll*cur*x+f[j])%MOD;
                }
                res[i]=cur;
            }
            return;
        }
        ll sz=tmpf2[node].size()-1;
        for(register int i=0;i<sz+1;i++)
        {
            tmp[i]=tmpf2[node][i];
        }
        clr(tmpf3[depth],sz+10),mod(fd,sz+1,f,tmp,tmpf3[depth]);
        ll mid=(l+r)>>1;
        dnc2(sz,tmpf3[depth],depth+1,l,mid,node<<1,res);
        dnc2(sz,tmpf3[depth],depth+1,mid+1,r,(node<<1)|1,res);
        for(register int i=0;i<sz;i++)
        {
            tmpf3[depth][i]=0;
        }
    }
    inline void eval(ll fd,ll pcnt,ll *f,ll *pts,ll *res)
    {
        dnc(pts,0,pcnt-1,1),dnc2(fd,f,0,0,pcnt-1,1,res);
    }
    int main()
    {
        setupOmg(524288),n=read(),x=read(),y=read();
        for(register int i=0;i<n-1;i++)
        {
            from=read(),to=read();
            addEdge(from,to),addEdge(to,from);
        }
        dfs(x,0);
        while(y!=x)
        {
            st[++tp]=y,p[y]=1,y=fa[y];
        }
        st[++tp]=y,p[y]=1,y=fa[y],z=n+1,fct=1;
        for(register int i=1;i<=tp;i++)
        {
            c[i]=1;
            for(register int j=last[st[i]];j;j=ed[j].prev)
            {
                if(!p[ed[j].to])
                {
                    c[i]+=sz[ed[j].to];
                }
            }
        }
        for(register int i=1;i<=n;i++)
        {
            fct=(li)fct*i%MOD;
            if(!p[i])
            {
                z=(li)z*sz[i]%MOD;
            }
        }
        fct=(li)fct*(n+1)%MOD;
        for(register int i=1;i<=tp;i++)
        {
            c[i]=(c[i-1]+c[i])%MOD;
        }
        dnc(c,1,tp+1,1);
        for(register int i=1;i<=tp+1;i++)
        {
            pr[i-1]=(li)i*tmpf2[1][i]%MOD;
        }
        memset(tmpf2,0,sizeof(tmpf2)),eval(n+1,tp+1,pr,c,px);
        for(register int i=0;i<=tp;i++)
        {
            cur=(li)fct*qpow((li)px[i]*z%MOD,MOD-2)%MOD;
            res=(res+(((tp-i)&1)?MOD-cur:cur))%MOD;
        }
        printf("%d
    ",(li)res*499122177%MOD);
    }
    
  • 相关阅读:
    对deferred(延迟对象)的理解
    string 、char* 、 char []的转换
    char* 和 cha[]
    层序遍历二叉树
    之字形打印二叉树
    右值
    函数指针(待修改)
    top k

    哈夫曼编码
  • 原文地址:https://www.cnblogs.com/Karry5307/p/13041340.html
Copyright © 2011-2022 走看看