zoukankan      html  css  js  c++  java
  • P3702 [SDOI2017]序列计数 (三模数NTT)

    题意:

    Alice想要得到一个长度为$n$的序列,序列中的数都是不超过$m$的正整数,而且这$n$个数的和是$p$的倍数。

    Alice还希望,这$n$个数中,至少有一个数是质数。

    Alice想知道,有多少个序列满足她的要求。

    思路:
    利用补集思想(容斥原理),答案就是

     在集合S(1~m)中有序、可重复地选n个数,使其和模p为0的方案数,减去在集合S'(1~m中的非质数,即1与所有合数)中同样选n个数、使其和模p为0的方案数(类似模型:洛谷P3321)。由于模数20170408不是NTT模数,多项式乘法需用三个NTT模数分别计算,再用中国剩余定理合并。

    Code:

    #include <map>
    #include <set>
    #include <array>
    #include <queue>
    #include <stack>
    #include <cmath>
    #include <vector>
    #include <cstdio>
    #include <cstring>
    #include <sstream>
    #include <iostream>
    #include <stdlib.h>
    #include <algorithm>
    #include <unordered_map>
    
    using namespace std;
    
    typedef long long ll;
    // Not referenced by the code in this file (template boilerplate).
    typedef pair<int, int> PII;
    
    // scanf shorthands; not used by the code in this file (sd appears only in a commented-out line in main).
    #define sd(a) scanf("%d", &a)
    #define sdd(a, b) scanf("%d%d", &a, &b)
    #define slld(a) scanf("%lld", &a)
    #define slldd(a, b) scanf("%lld%lld", &a, &b)
    // The three NTT-friendly primes used by the three-modulus convolution in mul();
    // ntt() uses 3 as the primitive root for each of them. Same values as Mod[] below.
    #define m1 998244353
    #define m2 469762049
    #define m3 1004535809
    
    // Capacity of every polynomial buffer; must exceed the NTT length chosen in solve()
    // (the smallest power of two greater than 2p - 1).
    const int N = 3e2 + 10;
    // Sieve capacity; must exceed m.
    const int M = 2e7 + 20;
    // Answer modulus. It is even, so it cannot be used directly as an NTT modulus.
    const int mod = 20170408;
    // INF and PI are not referenced by the code in this file.
    const int INF = 0x3f3f3f3f;
    const double PI = acos(-1.0);
    const int Mod[] = {998244353, 469762049, 1004535809};
    
    // Input: sequence length n, value bound m, and the modulus p that the sum must be divisible by.
    int n, m, p;
    // Bit-reversal permutation table, rebuilt by change() on every call.
    int rev[N];
    // vis[r] / h[r]: how many values in [1, m] (all values / non-primes only) are congruent to r mod p.
    ll vis[N], h[N];
    // Output of the sieve get(): the first cnt entries are the primes found, in increasing order.
    int primes[M], cnt = 0;
    // st[i] is true when i is 1 or composite (set by get()).
    bool st[M];
    
    // Linear (Euler) sieve on [1, n].
    // Afterwards primes[0..cnt) lists every prime <= n in increasing order, and
    // st[i] is true exactly when i is 1 or composite (i.e. not prime).
    // Each composite is marked once, via its smallest prime factor.
    void get(int n)
    {
        st[1] = true; // 1 is not prime
        for (int x = 2; x <= n; ++x)
        {
            if (!st[x])
                primes[cnt++] = x; // never marked by a smaller number -> prime

            for (int k = 0; primes[k] <= n / x; ++k)
            {
                const int q = primes[k];
                st[x * q] = true;
                // Once q divides x, any larger prime would not be the smallest factor of x * prime.
                if (x % q == 0)
                    break;
            }
        }
    }
    
    // Modular exponentiation by repeated squaring: returns a^b mod p for b >= 0.
    // For b == 0 the result is 1 (not reduced by p). Operands are expected to stay
    // below roughly 3e9 so that the intermediate products fit in a long long.
    ll qmi(ll a, ll b, ll p)
    {
        ll acc = 1;
        for (ll base = a; b != 0; b >>= 1)
        {
            if (b & 1)
                acc = acc * base % p;
            base = base * base % p;
        }
        return acc;
    }
    
    // Permutes y[0..len) into bit-reversed index order, the input layout expected by the
    // iterative NTT. len is expected to be a power of two; rev[] is rebuilt on every call.
    void change(ll y[], int len)
    {
        // rev[i] is derived from rev[i >> 1]: shift right, then set the top bit if i is odd.
        for (int i = 0; i < len; ++i)
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);

        // Swap each pair exactly once (only when i < rev[i]).
        for (int i = 0; i < len; ++i)
            if (i < rev[i])
                swap(y[i], y[rev[i]]);
    }
    
    // In-place iterative number-theoretic transform of y[0..len) over Z_MOD.
    // on == 1 : forward transform using primitive root 3.
    // on == -1: inverse transform (inverse root, then every entry scaled by len^{-1}).
    // Expects MOD prime with primitive root 3, len a power of two dividing MOD - 1,
    // and entries of y already in [0, MOD).
    void ntt(ll y[], int len, int on, ll MOD)
    {
        change(y, len);
        for (int step = 2; step <= len; step <<= 1)
        {
            const int half = step >> 1;
            // Principal step-th root of unity (or its inverse for the inverse transform).
            ll root = qmi(3, (MOD - 1) / step, MOD);
            if (on == -1)
                root = qmi(root, MOD - 2, MOD);

            for (int base = 0; base < len; base += step)
            {
                ll w = 1;
                for (int k = 0; k < half; ++k)
                {
                    const ll u = y[base + k];
                    const ll t = w * y[base + k + half] % MOD;
                    y[base + k] = (u + t) % MOD;
                    y[base + k + half] = (u - t + MOD) % MOD;
                    w = w * root % MOD;
                }
            }
        }

        if (on != -1)
            return;

        const ll inv = qmi(len, MOD - 2, MOD);
        for (int i = 0; i < len; ++i)
            y[i] = y[i] * inv % MOD;
    }
    
    // Computes a * b mod p by binary doubling, so no intermediate value exceeds 2p.
    // Used for CRT terms modulo m1*m2 (about 4.7e17), where a direct 64-bit product
    // would overflow. Assumes 0 <= a < p and b >= 0.
    ll mult(ll a, ll b, ll p)
    {
        ll sum = 0;
        for (; b != 0; b >>= 1)
        {
            if (b & 1)
                sum = (sum + a) % p;
            a = (a + a) % p;
        }
        return sum;
    }
    
    ll A[N], B[N], C[N], D[N];
    void mul(ll a[], ll b[], ll res[], ll len)
    {
        memcpy(A, a, sizeof(A));
        memcpy(B, a, sizeof(B));
        memcpy(C, a, sizeof(C));
        memcpy(D, b, sizeof(D));
    
        ntt(A, len, 1, Mod[0]);
        ntt(D, len, 1, Mod[0]);
        for (int i = 0; i < len; i++)
        {
            A[i] = A[i] * D[i] % Mod[0];
        }
        ntt(A, len, -1, Mod[0]);
    
        memcpy(D, b, sizeof(D));
        ntt(B, len, 1, Mod[1]);
        ntt(D, len, 1, Mod[1]);
        for (int i = 0; i < len; i++)
        {
            B[i] = B[i] * D[i] % Mod[1];
        }
        ntt(B, len, -1, Mod[1]);
    
        memcpy(D, b, sizeof(D));
        ntt(C, len, 1, Mod[2]);
        ntt(D, len, 1, Mod[2]);
        for (int i = 0; i < len; i++)
        {
            C[i] = C[i] * D[i] % Mod[2];
        }
        ntt(C, len, -1, Mod[2]);
    
        ll M12 = 1ll * m1 * m2;
        ll inv2 = qmi(m2, m1 - 2, m1);
        ll inv1 = qmi(m1, m2 - 2, m2);
        ll mul2 = 1ll * m2 * inv2 % M12;
        ll mul1 = 1ll * m1 * inv1 % M12;
        ll inv = qmi(M12 % m3, m3 - 2, m3);
        ll m12 = M12 % mod;
        ll c1, c2, c3, c4, q;
    
        for (int i = 0; i <= (p << 1); i++)
        {
            c1 = A[i], c2 = B[i], c3 = C[i];
            c4 = (mult(c1, mul2, M12) + mult(c2, mul1, M12)) % M12;
            q = ((c3 - c4) % m3 + m3) % m3 * inv % m3;
            res[i] = (q * m12 % mod + c4) % mod;
        }
        for (int i = p; i < len; i++)
        {
            res[i % p] = (res[i % p] + res[i]) % mod;
            res[i] = 0;
        }
    }
    
    // Result buffer written by qmi_ntt() and read by solve().
    ll res[N];

    // Raises polynomial y to the n-th power under mul() (exponents cyclic modulo p,
    // coefficients modulo `mod`) by binary exponentiation; the result is left in res[].
    // y is overwritten, because it is squared in place.
    void qmi_ntt(ll y[], int len, int n)
    {
        memset(res, 0, sizeof(res));
        res[0] = 1; // start from the constant polynomial 1
        for (; n; n >>= 1)
        {
            if (n & 1)
                mul(res, y, res, len);
            mul(y, y, y, len);
        }
    }
    
    // Only ans1 is used (by solve); mid, ans and ans2 are not referenced in this file.
    ll mid[N], ans[3], ans1, ans2;

    // Reads n, m, p and prints the number of length-n sequences over [1, m] whose sum is
    // divisible by p and that contain at least one prime, modulo `mod`.
    // Complement counting: (all sequences with sum = 0 mod p)
    //                    - (sequences using only non-primes with sum = 0 mod p).
    void solve()
    {
        cin >> n >> m >> p;

        get(m); // afterwards st[i] is true iff i is 1 or composite
        for (int i = 1; i <= m; i++)
        {
            vis[i % p]++;   // residue-class sizes over all of [1, m]
            if (st[i])
                h[i % p]++; // same, restricted to non-primes
        }

        // Smallest power of two > 2p - 1, so a product of two polynomials of
        // degree p - 1 (2p - 1 coefficients) fits without wrap-around.
        int len = 1;
        while (len <= p + p - 1)
            len <<= 1;

        qmi_ntt(vis, len, n);
        ans1 = res[0];

        qmi_ntt(h, len, n);
        ans1 = (ans1 - res[0] + mod) % mod;

        cout << ans1 << "\n";
    }
    
    // Program entry point: runs a single test case read from standard input.
    // Building with -DLOCAL redirects stdin from the author's local input file for debugging;
    // previously every non-judge build did this, which broke input for anyone without that path.
    int main()
    {
    #ifdef LOCAL
        freopen("/home/jungu/code/in.txt", "r", stdin);
    #endif
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        int T = 1; // this problem has exactly one test case
        while (T--)
        {
            solve();
        }

        return 0;
    }
  • 相关阅读:
    android之PackageManager简单介绍
    西门子PLC学习笔记二-(工作记录)
    node.js第十课(HTTPserver)
    ubuntu ???????????? no permissions 问题解决
    Web API 设计摘要
    公共 DNS server IP 地址
    用Unicode迎接未来
    vs2010公布时去除msvcp100.dll和msvcr100.dll图讲解明
    linux串口驱动分析
    C++ 中dynamic_cast<>的用法
  • 原文地址:https://www.cnblogs.com/jungu/p/14416333.html
Copyright © 2011-2022 走看看