zoukankan      html  css  js  c++  java
  • 树位DP

      给定一棵n个节点的树,和一个n的排列(b[i]),求树的DFS序中严格小于给定排列的方案数。n<=1e6

      这是一道。。树位DP题,我们沿用数位DP的思想,逐位确定。

      首先我们考虑最没有限制的情况,如果一个以x为根的树不受限制,它的DFS序有多少种。

      这个显然可以换根DP。先进行子树DP。设$f[x]$为答案,那么可得$f[x]=son[x]! imes prod f[son]$其中son[x]表示儿子个数。

      这个可以理解,就是表示从目前到遍历完子树的DFS序可以分成一段一段的,每一段是一个儿子的DFS,然后段与段是排列的关系,因为没有限制。

      然后我们就可以再次换根DP找到以x为根的整棵树的答案。

      枚举1到b[1]-1把他们的f加入答案。

      然后考虑以b[1]为根。

      我们现在用一个dfs来解决问题,问题是,以b序列中一个值为根,子树严格小于的方案数。

      那么进入dfs,我们目的是甩锅给下一层,然后递归解决。但是有的东西不能甩,本层必须解决。

      设当前位是len

      先分一下类。

      1.从len+1就小于:这种问题本层即可解决,找到接下来第一段的可能情况,也就是有多少儿子<b[len+1],然后假设为cnt,那么第一段有cnt种情况,剩下的仍然是排列和累乘。

      2.在某个儿子的子树中开始小于,这是一个递归的问题一会再说。

      3.从某个儿子开始不等,比如说前两棵子树都恰好覆盖了一段树上序列,然后接下来选一个小于b序列当前位的儿子作为下一位,那么应该是总儿子数减去已经让它完全覆盖的儿子数,这样得到了剩下可以选的儿子数,然后我在找到可以选的中所有小于b当前位的儿子数,还是第一位的情况*剩下的排列和累乘。

      那么我们考虑顺着b数组来捋,解决第二个问题。

      我们进行一个儿子次的循环。

      循环内部每次找到一个儿子等于b[late],late为上一个儿子覆盖完后到的b序列的位置,第一个则为len+1,相当于给挨个拿儿子往b序列上贴,接下来我们找到了一个和当前问题一样的问题,找到这个儿子在限制下的排列数,果断甩锅,当一个儿子的子树不能完全贴到b上,break。

      但是我们遇到了一个问题,怎么判断这个儿子的子树能不能把子树的size个点全贴到b上呢? 

      我们就需要用一个东西来记录这个儿子的子树是否能够吧子树个size全贴到b上,发现这个也是可以递归解决的,

      用dfs返回结构体也好,全局变量修改也罢,总之我的dfs要返回一个flag,表示能不能全贴上,这个flag是1必须是所有的儿子都能按顺序贴到b上,即儿子的flag都是1,具体实现就是我之前的儿子次循环真的进行了儿子次,并没有从中间break掉。当然循环中如果找不到一个儿子等于b[late]也要break。

      然后在顺一下思路:分三类,第一类可以一进dfs就算完,第二类是通过枚举儿子是否等于b[late]并dfs判断能不能全部贴到序列上,如果能,我累加儿子子树中开始严格小于的答案,然后接着吧late+=size[son],相当于把这个儿子贴到序列上,然后累加一下第三类答案也就是从下一个儿子处开始小于b的答案,接着找下一个儿子等于b[late]的子树中小于的答案……直至循环结束。

      交叉着进行二三类答案的计算。

      当递归到叶子节点时,处理flag,如果我的值等于b的当前值为1,否则为0,然后就是返回值,如果我的值小于b[len]那么返回1,表示递归的一条链是可以严格小于的。

      然后问题就解决了。

      然鹅会T。

      观察一下数据范围,1e6,但是在dfs的过程中是对于每个点我进行了两层循环,也就是说每个点被作为儿子枚举了n次,复杂度是$O(n^{2})$的,复杂度瓶颈卡在了我在找一个儿子是否等于b[late]的时候是枚举所有儿子的,接下来就很简单了,用一个数据结构维护一下每个点的儿子,支持删除,和单点,区间查询,splay和sgtree都可以,但我觉得动态开点的sgtree好写(得多)。然后就能A了。

      

    #include<cstring>
    #include<iostream>
    #include<cstdio>
    #define ll long long
    using namespace std;
    const int N=300020,mod=1e9+7;
    int fr[N],b[N],son[N],size[N],fa[N],flag,tt,n;
    ll f[N],fac[N];
    bool v[N];
    struct node{int to,pr;}mo[N*2];
    long long rd()
    {
        long long s=0,w=1;
        char cc=getchar();
        while(cc<'0'||cc>'9') {if(cc=='-') w=-1;cc=getchar();}
        while(cc>='0'&&cc<='9') s=(s<<3)+(s<<1)+cc-'0',cc=getchar();
        return s*w;
    }
    ll inv(ll a)
    {
        ll ans=1,k=mod-2;
        for(;k;k>>=1,a=1ll*a*a%mod) if(k&1) ans=1ll*ans*a%mod;
        return ans;
    }
    void add(int x,int y)
    {
        mo[++tt].to=y;
        mo[tt].pr=fr[x];
        fr[x]=tt;
    }
    void first_dfs(int x)
    {
        ll ans=1;size[x]=1;
        for(int i=fr[x];i;i=mo[i].pr)
        {
            int to=mo[i].to;
            if(to==fa[x])continue;
            son[x]++;
            fa[to]=x;
            first_dfs(to);
            ans=1ll*ans*f[to]%mod;
            size[x]+=size[to];
        }
        f[x]=1ll*fac[son[x]]*ans%mod;
    }
    void re_dfs(int x)
    {
        for(int i=fr[x];i;i=mo[i].pr)
        {
            int to=mo[i].to;
            if(to==fa[x]) continue;
            //cout<<x<<" "<<to<<endl;
            //cout<<to<<" "<<1ll*f[x]%mod*inv[f[to]]%mod<<endl;
            //cout<<son[to]<<endl;
            f[to]=1ll*f[to]*f[x]%mod*inv(f[to])%mod*inv(son[x])%mod*(++son[to])%mod;
            re_dfs(to);
        }
    }
    ll dfs(int len,int x)
    {
        if(son[x]==0) 
        {
            flag=x==b[len];
            return x<b[len];
        }
        long long ans=0,sum=0;
        for(int i=fr[x];i;i=mo[i].pr)
            if(mo[i].to!=fa[x]&&mo[i].to<b[len+1])sum++;
        //cout<<ans<<endl;
        ans=(ans+1ll*sum*f[x]%mod*inv(son[x])%mod)%mod;
        //cout<<x<<" "<<ans<<endl;
        ll lat=len+1,pi=1ll*f[x]*inv(fac[son[x]])%mod;sum=son[x];
        //cout<<x<<" "<<" "<<lat<<endl;
        flag=1;
        for(int k=1;k<=son[x];k++)
        {
            bool jud=0;
            for(int i=fr[x];i;i=mo[i].pr)
            {
                int to=mo[i].to;
                if(to==fa[x]) continue;
                //cout<<to<<" "<<b[lat]<<endl;
                if(to==b[lat])
                {
                    v[to]=1;
                    pi=pi*inv(f[to])%mod;
                    long long tmp=dfs(lat,to);
                    ans=(ans+1ll*tmp*fac[sum-1]%mod*pi%mod)%mod;
                //    cout<<to<<" "<<b[lat]<<" "<<tmp<<" "<<flag<<endl;
                    lat=lat+size[to],sum--;
                    jud=1;
                    break;
                }
            }
            if(!flag) break;
            if(!jud) break;
            int cnt=0;
            for(int i=fr[x];i;i=mo[i].pr)
            {
                int to=mo[i].to;
                if(to==fa[x]) continue;
                if(v[to]) continue;
                if(to<b[lat]) cnt++;
            }
            //cout<<b[lat]<<" s"<<cnt<<" "<<sum<<" "<<pi<<" "<<ans<<endl;
            if(sum!=son[x])ans=(ans+1ll*pi*cnt%mod*fac[sum-1]%mod)%mod;
            //cout<<ans<<endl;
            
        }
        if(flag==1&&sum==0) flag=1;
        else flag=0;
        return ans;
    }
    ll solve()
    {
        ll ans=0;
        first_dfs(1);
        re_dfs(1);
        for(int i=1;i<b[1];i++)ans=(ans+f[i])%mod;
        //cout<<ans<<endl;
        memset(son,0,sizeof(son));
        memset(fa,0,sizeof(fa));
        memset(f,0,sizeof(f));
        memset(size,0,sizeof(size));
        first_dfs(b[1]);
        ans=(ans+dfs(1,b[1]))%mod;
        return ans;
    }
    int main()
    {
        //freopen("travel2.in","r",stdin);
        //freopen("data1.in","r",stdin);
        //freopen("data1.out","w",stdout);
        n=rd();fac[0]=1;
        for(int i=1;i<=n;i++)b[i]=rd(),fac[i]=1ll*fac[i-1]*i%mod;
        for(int i=1,x,y;i<n;i++)
        {
            x=rd(),y=rd();
            add(x,y);add(y,x);
        }
        printf("%lld
    ",solve());
    }
    /*
    g++ -std=c++11 1.cpp -o 1
    ./1
    6
    1 3 6 2 5 4 
    1 2
    1 3
    1 4
    4 5
    1 6
    */
    80pts更容易理解
    #include<cstring>
    #include<iostream>
    #include<cstdio>
    #define ll long long
    using namespace std;
    const int N=300020,mod=1e9+7;
    int fr[N],b[N],son[N],size[N],fa[N],flag,tt,n;
    ll f[N],fac[N];
    bool v[N];
    struct node{int to,pr;}mo[N*2];
    long long rd()
    {
        long long s=0,w=1;
        char cc=getchar();
        while(cc<'0'||cc>'9') {if(cc=='-') w=-1;cc=getchar();}
        while(cc>='0'&&cc<='9') s=(s<<3)+(s<<1)+cc-'0',cc=getchar();
        return s*w;
    }
    ll inv(ll a)
    {
        ll ans=1,k=mod-2;
        for(;k;k>>=1,a=1ll*a*a%mod) if(k&1) ans=1ll*ans*a%mod;
        return ans;
    }
    void add(int x,int y)
    {
        mo[++tt].to=y;
        mo[tt].pr=fr[x];
        fr[x]=tt;
    }
    void first_dfs(int x)
    {
        ll ans=1;size[x]=1;
        for(int i=fr[x];i;i=mo[i].pr)
        {
            int to=mo[i].to;
            if(to==fa[x])continue;
            son[x]++;
            fa[to]=x;
            first_dfs(to);
            ans=1ll*ans*f[to]%mod;
            size[x]+=size[to];
        }
        f[x]=1ll*fac[son[x]]*ans%mod;
    }
    void re_dfs(int x)
    {
        for(int i=fr[x];i;i=mo[i].pr)
        {
            int to=mo[i].to;
            if(to==fa[x]) continue;
            //cout<<x<<" "<<to<<endl;
            //cout<<to<<" "<<1ll*f[x]%mod*inv[f[to]]%mod<<endl;
            //cout<<son[to]<<endl;
            f[to]=1ll*f[to]*f[x]%mod*inv(f[to])%mod*inv(son[x])%mod*(++son[to])%mod;
            re_dfs(to);
        }
    }
    ll dfs(int len,int x)
    {
        if(son[x]==0) 
        {
            flag=x==b[len];
            return x<b[len];
        }
        long long ans=0,sum=0;
        for(int i=fr[x];i;i=mo[i].pr)
            if(mo[i].to!=fa[x]&&mo[i].to<b[len+1])sum++;
        //cout<<ans<<endl;
        ans=(ans+1ll*sum*f[x]%mod*inv(son[x])%mod)%mod;
        //cout<<x<<" "<<ans<<endl;
        ll lat=len+1,pi=1ll*f[x]*inv(fac[son[x]])%mod;sum=son[x];
        //cout<<x<<" "<<" "<<lat<<endl;
        flag=1;
        for(int k=1;k<=son[x];k++)
        {
            bool jud=0;
            for(int i=fr[x];i;i=mo[i].pr)
            {
                int to=mo[i].to;
                if(to==fa[x]) continue;
                //cout<<to<<" "<<b[lat]<<endl;
                if(to==b[lat])
                {
                    v[to]=1;
                    pi=pi*inv(f[to])%mod;
                    long long tmp=dfs(lat,to);
                    ans=(ans+1ll*tmp*fac[sum-1]%mod*pi%mod)%mod;
                //    cout<<to<<" "<<b[lat]<<" "<<tmp<<" "<<flag<<endl;
                    lat=lat+size[to],sum--;
                    jud=1;
                    break;
                }
            }
            if(!flag) break;
            if(!jud) break;
            int cnt=0;
            for(int i=fr[x];i;i=mo[i].pr)
            {
                int to=mo[i].to;
                if(to==fa[x]) continue;
                if(v[to]) continue;
                if(to<b[lat]) cnt++;
            }
            //cout<<b[lat]<<" s"<<cnt<<" "<<sum<<" "<<pi<<" "<<ans<<endl;
            if(sum!=son[x])ans=(ans+1ll*pi*cnt%mod*fac[sum-1]%mod)%mod;
            //cout<<ans<<endl;
            
        }
        if(flag==1&&sum==0) flag=1;
        else flag=0;
        return ans;
    }
    ll solve()
    {
        ll ans=0;
        first_dfs(1);
        re_dfs(1);
        for(int i=1;i<b[1];i++)ans=(ans+f[i])%mod;
        //cout<<ans<<endl;
        memset(son,0,sizeof(son));
        memset(fa,0,sizeof(fa));
        memset(f,0,sizeof(f));
        memset(size,0,sizeof(size));
        first_dfs(b[1]);
        ans=(ans+dfs(1,b[1]))%mod;
        return ans;
    }
    int main()
    {
        //freopen("travel2.in","r",stdin);
        //freopen("data1.in","r",stdin);
        //freopen("data1.out","w",stdout);
        n=rd();fac[0]=1;
        for(int i=1;i<=n;i++)b[i]=rd(),fac[i]=1ll*fac[i-1]*i%mod;
        for(int i=1,x,y;i<n;i++)
        {
            x=rd(),y=rd();
            add(x,y);add(y,x);
        }
        printf("%lld
    ",solve());
    }
    /*
    g++ -std=c++11 1.cpp -o 1
    ./1
    6
    1 3 6 2 5 4 
    1 2
    1 3
    1 4
    4 5
    1 6
    */
    100pts
  • 相关阅读:
    vijos1746 floyd
    总结
    用javascript代码拼html
    异步编程学习
    SELECT
    设计 Azure SQL 数据库,并使用 C# 和 ADO.NET 进行连接
    H2数据库
    ASP.NET 文档
    ASP.NET MVC
    ASP.NET Core 中的 Razor 页面介绍
  • 原文地址:https://www.cnblogs.com/starsing/p/11401623.html
Copyright © 2011-2022 走看看