zoukankan      html  css  js  c++  java
  • 2017 CCPC 杭州 HDU 6270 Marriage (NTT,容斥)

    题目:传送门

    题意

    有 n 个家庭,每个家庭有 ai 个男孩和 bi 个女孩,n 个家庭总的男孩等于总的女孩。对于来自 i 家庭的男孩他只能和不来自 i 家庭的女孩结婚,也就是来自同个家庭的男孩女孩不能结婚。问有多少种方案,使得这 这些男孩女孩都能成功结婚。

    思路

    参考博客:

    对于一个有 x 个男孩和 y 个女孩的家庭来说,有且仅有 k 对来自这个家庭的男孩女孩结婚(近亲结婚)的方案数是:

    C(x, k) * C(y,k) * k!

    那么如果在第一个家庭选 k1 对近亲结婚,第二个家庭 k2 对......第 n 个家庭 kn 对,剩下的自由组合,最后这种方案至少有 k1+k2...+kn 对近亲结婚。

    那我们对每个家庭构造一个多项式:

    c0 + c1*x + c2*x^2 + .... + cm*x^m  (m = min(x, y))

    把这 n 个多项式乘起来,得到的多项式的 x^k 的系数 ck 代表的就是至少有 k 对近亲结婚的方案数。

    因为代表的是至少,所以最后还需要容斥一下。

    n个多项式相乘,复杂度跟多项式的长度有很大关系,n个多项式的长度就是所有男孩的总数;所以复杂度其实是 o(nlognlogn)的

    #include <bits/stdc++.h>
    #define LL long long
    #define ULL unsigned long long
    #define UI unsigned int
    #define mem(i, j) memset(i, j, sizeof(i))
    #define rep(i, j, k) for(int i = j; i <= k; i++)
    #define dep(i, j, k) for(int i = k; i >= j; i--)
    #define pb push_back
    #define make make_pair
    #define INF 0x3f3f3f3f
    #define inf LLONG_MAX
    #define PI acos(-1)
    #define fir first
    #define sec second
    #define lb(x) ((x) & (-(x)))
    #define dbg(x) cout<<#x<<" = "<<x<<endl;
    using namespace std;
    
    const int N = 1e6 + 5;
    const LL mod = 998244353;
    const LL g = 3;
    
    int n, all, cnt;
    LL fac[N];
    LL x1[N], x2[N];
    vector < LL > a[N];
    
    LL ksm(LL a, LL b) {
        LL res = 1LL;
        while(b) {
            if(b & 1) res = res * a % mod;
            a = a * a % mod; b >>= 1;
        }
        return res;
    }
    
    LL C(int n, int m) { return m > n ? 0 : fac[n] * ksm(fac[m] * fac[n - m] % mod, mod - 2) % mod; }
    
    struct cmp{
        bool operator()(int A, int B) {
            return a[A].size() > a[B].size();
        }
    };
    priority_queue <LL, vector<LL>, cmp> Q;
    
    void change(LL y[], int len){
        for (int i = 1, j = len / 2; i < len - 1; i++){
            if (i < j) swap(y[i], y[j]);
            int k = len / 2;
            while (j >= k){
                j -= k;
                k /= 2;
            }
            if (j < k) j += k;
        }
    }
    
    void ntt(LL y[], int len, int on){
        change(y, len);
        for (int h = 2; h <= len; h <<= 1){
            LL wn = ksm(g, (mod - 1) / h);
            if (on == -1) wn = ksm(wn, mod - 2);
            for (int j = 0; j < len; j += h){
                LL w = 1ll;
                for (int k = j; k < j + h / 2; k++){
                    LL u = y[k];
                    LL t = w * y[k + h / 2] % mod;
                    y[k] = (u + t) % mod;
                    y[k + h / 2] = (u - t + mod) % mod;
                    w = w * wn % mod;
                }
            }
        }
    
        if (on == -1){
            LL t = ksm(len, mod - 2);
            rep(i, 0, len - 1) y[i] = y[i] * t % mod;
        }
    }
    
    void mul(vector <LL> &a, vector <LL> &b, vector <LL> &c){
        int len = 1;
        int sz1 = a.size(), sz2 = b.size();
    
        while (len <= sz1 + sz2 - 1) len <<= 1;
    
        rep(i, 0, sz1 - 1) x1[i] = a[i];
        rep(i, sz1, len)   x1[i] = 0;
    
        rep(i, 0, sz2 - 1) x2[i] = b[i];
        rep(i, sz2, len)   x2[i] = 0;
    
        ntt(x1, len, 1);
        ntt(x2, len, 1);
    
        rep(i, 0, len - 1) x1[i] = x1[i] * x2[i];
    
        ntt(x1, len, -1);
    
        vector <LL>().swap(c);
        rep(i, 0, sz1 + sz2 - 2) c.push_back(x1[i]);
    }
    
    void solve() {
    
        scanf("%d", &n);
        rep(i, 0, n) vector < LL >().swap(a[i]);
        while(!Q.empty()) Q.pop();
    
        all = 0;
    
        rep(i, 1, n) {
            int x, y;
            scanf("%d %d", &x, &y);
            a[i].resize(min(x,y)+1);
            rep(j, 0, min(x,y)) a[i][j] = C(x, j) * C(y, j) % mod * fac[j] % mod;
            Q.push(i);  all += x;
        }
    
        cnt = n;
    
        rep(i, 1, n - 1) {
            int pos1 = Q.top(); Q.pop();
            int pos2 = Q.top(); Q.pop();
    
            mul(a[pos1], a[pos2], a[++cnt]);
    
            vector < LL >().swap(a[pos1]);
            vector < LL >().swap(a[pos2]);
    
            Q.push(cnt);
        }
    
        LL ans = 0LL, flag = 1LL;
    
        rep(i, 0, (int)(a[cnt].size()) - 1) {
            ans = ans + flag * fac[all - i] * a[cnt][i] % mod;
            ans = (ans + mod) % mod;
            flag *= -1;
        }
        printf("%lld
    ", ans);
    }
    
    
    int main() {
    
        fac[0] = 1LL; rep(i, 1, N - 5) fac[i] = 1LL * i * fac[i - 1] % mod;
    
        int _; scanf("%d", &_);
        while(_--) solve();
    
    //    solve();
    
        return 0;
    }
  • 相关阅读:
    Spring中配置和读取多个Properties文件
    python 数据清洗
    python excel 文件合并
    Pandas -- Merge,join and concatenate
    python 数据合并
    python pandas
    python Numpy
    EXCEL 导入 R 的几种方法 R—readr和readxl包
    R语言笔记完整版
    第十三章 多项式回归分析
  • 原文地址:https://www.cnblogs.com/Willems/p/13839974.html
Copyright © 2011-2022 走看看