zoukankan      html  css  js  c++  java
  • 完全图的最小生成树计数

    题意

    给定一个长度为 (n) 的数组 (a_1,a_2,dots ,a_n),有一幅完全图,满足 ((u,v)) 的边权为 (a_u ext{xor} a_v) 。求边权和最小的生成树,你需要输出边权和以及方案数对 (1e9+7) 取模的值(边权和不要取模)。

    (1leq n leq 10^5,0leq a_i <2^{30})

    题目链接:https://vjudge.net/problem/51Nod-1601

    分析

    求边权和直接按照异或最小生成树的模板求。

    求方案数时,当两个联通块相连,如果相连的最小边权的边存在多条,那么按照乘法原理,应该乘到答案中。同时,由于在字典树中,点权相同的点位于同一个点,这些点之间也要连边。可以转化为完全图的生成树个数,根据 ( ext{prufer}) 序列,答案为 (n^{n-2}) 个。注意,当点是一个重复的点,如果没有计算,要上传到其父亲节点,直到计算为止。

    代码

    #include <bits/stdc++.h>
    #define pb push_back
    using namespace std;
    typedef long long ll;
    const int mod=1e9+7;
    const int N=1e5+5;
    const int maxn=2e6+5;
    int trie[maxn][2],a[N],cnt;
    int id[maxn],num[maxn],n;
    ll ans,cot;
    vector<int>value[N];
    ll power(ll x,ll y)
    {
        ll res=1;
        x%=mod;
        while(y)
        {
            if(y&1) res=res*x%mod;
            x=x*x%mod;
            y>>=1;
        }
        return res;
    }
    void add(int x,int k)
    {
        int rt=1;
        for(int i=29;i>=0;i--)
        {
            int t=((x>>i)&1);
            if(trie[rt][t]==0)
                trie[rt][t]=++cnt;
            rt=trie[rt][t];
        }
        id[rt]=k;
        value[k].pb(x);
        num[rt]+=(upper_bound(a+1,a+1+n,x)-lower_bound(a+1,a+1+n,x));
        //cout<<"->"<<num[rt]<<endl;
    }
    int matching(int x,int rt,int d)
    {
        int res=(1<<d);
        for(int i=d-1;i>=0;i--)
        {
            int t=((x>>i)&1);
            if(trie[rt][t]>0)
                rt=trie[rt][t];
            else
            {
                rt=trie[rt][1-t];
                res|=(1<<i);
            }
        }
        return res;
    }
    void solve(int rt,int d)
    {
        if(trie[rt][0]>0) solve(trie[rt][0],d-1);
        if(trie[rt][1]>0) solve(trie[rt][1],d-1);
        if(trie[rt][0]>0&&trie[rt][1]>0)
        {
            int min_xor=(1<<30);
            int x=id[trie[rt][0]],y=id[trie[rt][1]];
            ll w=0;
            ll u=1,v=1;
            if(num[trie[rt][0]]>1) u=power(num[trie[rt][0]],num[trie[rt][0]]-2);
            if(num[trie[rt][1]]>1) v=power(num[trie[rt][1]],num[trie[rt][1]]-2);
            //cout<<"d="<<d<<" u="<<u<<" v="<<v<<endl;
            if(value[x].size()<value[y].size())
            {
                for(int i=0;i<value[x].size();i++)
                {
                    int tmp=value[x][i];
                    int val=matching(tmp,trie[rt][1],d-1);
                    int tn=upper_bound(a+1,a+1+n,tmp)-lower_bound(a+1,a+1+n,tmp);
                    int ct=upper_bound(a+1,a+1+n,(tmp^val))-lower_bound(a+1,a+1+n,(tmp^val));
                    ll res=1LL*tn*ct%mod*u%mod*v%mod;
                    if(val<min_xor)
                        min_xor=val,w=res;
                    else if(val==min_xor)
                        w=(w+res)%mod;
                    value[y].pb(tmp);
                }
                id[rt]=y;
            }
            else
            {
                for(int i=0;i<value[y].size();i++)
                {
                    int tmp=value[y][i];
                    int val=matching(tmp,trie[rt][0],d-1);
                    int tn=upper_bound(a+1,a+1+n,tmp)-lower_bound(a+1,a+1+n,tmp);
                    int ct=upper_bound(a+1,a+1+n,(tmp^val))-lower_bound(a+1,a+1+n,(tmp^val));
                    ll res=1LL*tn*ct%mod*u%mod*v%mod;
                    if(val<min_xor)
                        min_xor=val,w=res;
                    else if(val==min_xor)
                        w=(w+res)%mod;
                    value[x].pb(tmp);
                }
                id[rt]=x;
            }
            ans+=min_xor;
            cot=cot*w%mod;
        }
        else
        {
            if(trie[rt][0]>0||trie[rt][1]>0)
            {
                id[rt]=id[trie[rt][0]+trie[rt][1]];
                num[rt]=num[trie[rt][0]]+num[trie[rt][1]];
            }
        }
    }
    int main()
    {
        scanf("%d",&n);
        for(int i=1;i<=n;i++)
            scanf("%d",&a[i]);
        sort(a+1,a+1+n);
        cnt=1;
        add(a[1],1);
        for(int i=2;i<=n;i++)
        {
            if(a[i]!=a[i-1])
                add(a[i],i);
        }
        ans=0,cot=1;
        solve(1,30);
        if(num[1]>1) cot=cot*power(num[1],num[1]-2)%mod;
        printf("%lld
    %lld
    ",ans,cot);
        return 0;
    }
    
    
  • 相关阅读:
    PL/SQL Developer保存自定义界面布局
    SQL Server 2008中SQL增强之二:Top新用途
    泛型和集合
    Go语言
    软件架构师培训
    using的几种用法
    【十五分钟Talkshow】如何善用你的.NET开发环境
    心的感谢
    【缅怀妈妈系列诗歌】之四:妈妈,对不起
    PDA开发经验小结 (转共享)
  • 原文地址:https://www.cnblogs.com/1024-xzx/p/13646273.html
Copyright © 2011-2022 走看看