zoukankan      html  css  js  c++  java
  • 【做题】ZJOI2017仙人掌——组合计数

    原文链接 https://www.cnblogs.com/cly-none/p/ZJOI2017cactus.html

    给出一个(n)个点(m)条边的无向连通图,求有多少种加边方案,使得加完后得到一个仙人掌。

    (n leq 5 imes 10^5, m leq 10^6)

    首先,判定无解后,我们可以把每个环删掉,那么答案就是剩下的若干树的加边方案的乘积。

    于是就考虑一棵树怎么做。

    sol1

    (dp_i)表示在结点(i)的子树中的答案。考虑如何转移。

    注意到,假如我们把(i)和它子树中的某个点(v)连一条边,那么这样得到的方案数就与(i)(v)的链有关。对于链上最后一个点即(v),它能产生(dp_v)的贡献;否则,对于链上的点(a),若它在链上的下一个(深度较大的)点是(b),那么(a)能产生的贡献就是(a)删掉(b)这个子树后的dp值。整条链能产生的方案数就是所有贡献的乘积。

    于是我们令(sdp_i)表示在(i)的子树中,以(i)为一段的所有链的方案数之和。这记录的是这个子树的所有向(i)的父亲连边的方案数。(不连边其实就是(i)向外连)

    再考虑一棵子树向外连边有两种情况,一是直接连到当前的根上,二是和另一颗子树配对。

    所以我们令(g_i)表示(i)个元素,每个元素都能任意配对也能不配对(即连到根上)的方案数。那么就有(dp_i = prod_{v in child_i} sdp_v g_{|child_i|})。其中(child_i)表示结点(i)的所有孩子构成的集合。

    考虑如何计算(g_n)。我们可以枚举元素(n)的状态。有两种可能:

    • 不匹配。即(g_{n-1})
    • 匹配。那么枚举它和哪个元素匹配。即((n-1)g_{n-2})

    于是就能计算出所有(g_n)了。

    用类似的方法也能算出(sdp_i)。可以用前缀积和后缀积来避免计算逆元,做到(O(n))复杂度。

    下面代码是(O(n log n))的,(log n)在计算逆元上。

    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    typedef double db;
    typedef pair<int,int> pii;
    #define fir first
    #define sec second
    #define rep(i,a,b) for (int i = (a) ; i <= (b) ; ++ i)
    #define rrp(i,a,b) for (int i = (a) ; i >= (b) ; -- i)
    #define gc() getchar()
    template <typename tp>
    inline void read(tp& x) {
      x = 0; char tmp; bool key = 0;
      for (tmp = gc() ; !isdigit(tmp) ; tmp = gc())
        key = (tmp == '-');
      for ( ; isdigit(tmp) ; tmp = gc())
        x = (x << 3) + (x << 1) + (tmp ^ '0');
      if (key) x = -x;
    }
    
    const int N = 500010, M = 1000010, MOD = 998244353;
    int power(int a,int b) {
      int ret = 1;
      while (b) {
        if (b & 1) ret = 1ll * ret * a % MOD;
        a = 1ll * a * a % MOD;
        b >>= 1;
      }
      return ret;
    }
    inline void Add(int& x,int y) {
      x = x + y >= MOD ? x + y - MOD : x + y;
    }
    struct edge {
      int la,b;
    } con[M << 1];
    int tot,fir[N],n,m,ans;
    void add(int from,int to) {
      con[++tot] = (edge) {fir[from], to};
      fir[from] = tot;
    }
    int dfn[N], low[N], col[N], sta[N], top, ecnt, ccnt;
    pii edg[M];
    void dfs(int pos,int fa) {
      sta[low[pos] = dfn[pos] = ++ top] = pos;
      for (int i = fir[pos] ; i ; i = con[i].la) {
        if (con[i].b == fa) continue;
        if (dfn[con[i].b]) low[pos] = min(low[pos], dfn[con[i].b]);
        else {
          dfs(con[i].b, pos);
          low[pos] = min(low[pos], low[con[i].b]);
          if (low[con[i].b] >= dfn[pos]) {
            if (low[con[i].b] > dfn[pos])
              edg[++ecnt] = pii(pos, con[i].b);
            else ++ ccnt;
            top = dfn[pos];
          }
        }
      }
    }
    int dp[N], sdp[N], sz[N], jc[N], inv[N], vis[N];
    void fsd(int pos,int fa) {
      vis[pos] = 1;
      dp[pos] = 1;
      sz[pos] = 1;
      int num = 0;
      for (int i = fir[pos] ; i ; i = con[i].la) {
        if (con[i].b == fa) continue;
        fsd(con[i].b, pos);
        sz[pos] += sz[con[i].b];
        dp[pos] = 1ll * sdp[con[i].b] * dp[pos] % MOD;
        ++ num;
      }
      int tmp = 0, tmp1 = 0;
      for (int k = 0, ipw2 = 1 ; k * 2 <= num ; ++ k) {
        Add(tmp,1ll * jc[num] * inv[k] % MOD * inv[num - 2 * k] % MOD * ipw2 % MOD);
        ipw2 = 1ll * ipw2 * (MOD + 1) / 2 % MOD;
      }
      -- num;
      for (int k = 0, ipw2 = 1 ; k * 2 <= num ; ++ k) {
        Add(tmp1,1ll * jc[num] * inv[k] % MOD * inv[num - 2 * k] % MOD * ipw2 % MOD);
        ipw2 = 1ll * ipw2 * (MOD + 1) / 2 % MOD;
      }
      for (int i = fir[pos] ; i ; i = con[i].la) {
        if (con[i].b == fa) continue;
        Add(sdp[pos], 1ll * dp[pos] * power(sdp[con[i].b], MOD-2) % MOD * tmp1 % MOD * sdp[con[i].b] % MOD);
      }
      dp[pos] = 1ll * dp[pos] * tmp % MOD;
      Add(sdp[pos], dp[pos]);  
    }
    void init() {
      memset(fir,0,sizeof(int) * (n + 5));
      memset(dfn,0,sizeof(int) * (n + 5));
      tot = ecnt = ccnt = top = 0;
      ans = 1;
      memset(dp,0,sizeof(int) * (n + 5));
      memset(sdp,0,sizeof(int) * (n + 5));
      memset(vis,0,sizeof(int) * (n + 5));
    }
    void solve() {
      read(n), read(m);
      init();
      jc[0] = 1;
      rep (i, 1, n) jc[i] = 1ll * jc[i-1] * i % MOD;
      inv[n] = power(jc[n], MOD - 2);
      rrp (i, n-1, 0) inv[i] = 1ll * inv[i+1] * (i+1) % MOD;
      for (int i = 1, x, y ; i <= m ; ++ i) {
        read(x), read(y);
        add(x,y);
        add(y,x);
      }
      rep (i, 1, n) if (!dfn[i]) dfs(i, 0);
      if (n - 1 + ccnt != m) {
        puts("0");
        return;
      }
      tot = 0;
      memset(fir,0,sizeof(int) * (n + 5));
      rep (i, 1, ecnt) {
        add(edg[i].fir, edg[i].sec);
        add(edg[i].sec, edg[i].fir);
      }
      rep (i, 1, n) {
        if (vis[i]) continue;
        fsd(i, 0);
        ans = 1ll * dp[i] * ans % MOD;
      }
      ans = (ans % MOD + MOD) % MOD;
      printf("%d
    ", ans);
    }
    int main() {
      int T;
      read(T);
      while (T --)
        solve();
      return 0;
    }
    


    sol2

    上面做法对链的分析太复杂了,于是我们考虑直接用组合意义计算(sdp_i)

    分两种情况:

    • 让结点(i)负责向上连边。那么答案就是(prod_{v in child_i} sdp_v g_{|child_i|})
    • (i)的某个孩子内的结点向上连边。我们就枚举这是哪个孩子内的结点。即(|child_i| imes prod_{v in child_i} sdp_v g_{|child_i|-1})

    这样就简单多了。


    小结:这道题的巧妙之处在于对一些组合意义的运用,比较常规,且直接大力分析也能解决。

  • 相关阅读:
    汇编四(习题)
    汇编子程序模块化(near&far)
    win10关闭防火墙
    python中numpy中的shape()的使用
    文件的拷贝linux命令
    python中的os.path.dirname(__file__)
    ubuntu系统下安装及查看opencv版本
    用git命令行克隆项目及出现failed解决方案
    ERROR: Could not install packages due to an EnvironmentError: [Errno 13] Permission denied: '
    记录CenterNet代码编译成功运行
  • 原文地址:https://www.cnblogs.com/cly-none/p/ZJOI2017cactus.html
Copyright © 2011-2022 走看看