zoukankan      html  css  js  c++  java
  • XJTUOJ wmq的A×B Problem FFT/NTT

    wmq的A×B Problem

    发布时间: 2017年4月9日 17:06   最后更新: 2017年4月9日 17:07   时间限制: 3000ms   内存限制: 512M

    这是一个非常简单的问题。

    wmq如今开始学习乘法了!他为了训练自己的乘法计算能力,写出了n个整数,并且对每两个数a,b都求出了它们的乘积a×b。现在他想知道,在求出的n(n1)2个乘积中,除以给定的质数m余数为k(0k<m)的有多少个。

    第一行为测试数据的组数。

    对于每组测试数据,第一行为2个正整数n,m,2n,m60000,分别表示整数的个数以及除数。

    接下来一行有n个整数,满足0ai109

    保证总输出行数m3×105

    对每组数据输出m行,其中第i行为除以m余数为(i1)的有多少个。

    2
    4 5
    2 0 1 7
    4 2
    2 0 1 6
    3
    0
    2
    0
    1
    6
    0
     
    题解:
      m是素数,显然,利用原根性质
      i -> g^i
      此时g为m的原根
      首先 模剩余系 为0~m-1,那么模为 i * j 转化为 g^(i+j)
      即FFT求解,NTT也可过
      最后转化回来即可
    #include<bits/stdc++.h>
    using namespace std;
    #pragma comment(linker, "/STACK:102400000,102400000")
    #define ls i<<1
    #define rs ls | 1
    #define mid ((ll+rr)>>1)
    #define pii pair<int,int>
    #define MP make_pair
    typedef long long LL;
    const long long INF = 1e18+1LL;
    const double pi = acos(-1.0);
    const int N = 333333+10, M = 1e3+20,inf = 2e9,mod = 469762049;
    
    int MOD;
    inline int mul(int a, int b){
        return (long long)a * b % MOD;
    }
    int power(int a, int b){
        int ret = 1;
        for (int t = a; b; b >>= 1){
            if (b & 1)ret = mul(ret, t);
            t = mul(t, t);
        }
        return ret;
    }
    int cal_root(int mod)
    {
        int factor[20], num = 0, s = mod - 1;
        MOD = mod--;
        for (int i = 2; i * i <= s; i++){
            if (s % i == 0){
                factor[num++] = i;
                while (s % i == 0)s /= i;
            }
        }
        if (s != 1)factor[num++] = s;
        for (int i = 2;; i++){
            int j = 0;
            for (; j < num && power(i, mod / factor[j]) != 1; j++);
            if (j == num)return i;
        }
    }
    struct Complex {
        long double r , i ;
        Complex () {}
        Complex ( double r , double i ) : r ( r ) , i ( i ) {}
        Complex operator + ( const Complex& t ) const {
            return Complex ( r + t.r , i + t.i ) ;
        }
        Complex operator - ( const Complex& t ) const {
            return Complex ( r - t.r , i - t.i ) ;
        }
        Complex operator * ( const Complex& t ) const {
            return Complex ( r * t.r - i * t.i , r * t.i + i * t.r ) ;
        }
    } ;
    
    void FFT ( Complex y[] , int n , int rev ) {
        for ( int i = 1 , j , t , k ; i < n ; ++ i ) {
            for ( j = 0 , t = i , k = n >> 1 ; k ; k >>= 1 , t >>= 1 ) j = j << 1 | t & 1 ;
            if ( i < j ) swap ( y[i] , y[j] ) ;
        }
        for ( int s = 2 , ds = 1 ; s <= n ; ds = s , s <<= 1 ) {
            Complex wn = Complex ( cos ( rev * 2 * pi / s ) , sin ( rev * 2 * pi / s ) ) , w ( 1 , 0 ) , t ;
            for ( int k = 0 ; k < ds ; ++ k , w = w * wn ) {
                for ( int i = k ; i < n ; i += s ) {
                    y[i + ds] = y[i] - ( t = w * y[i + ds] ) ;
                    y[i] = y[i] + t ;
                }
            }
        }
        if ( rev == -1 ) for ( int i = 0 ; i < n ; ++ i ) y[i].r /= n ;
    }
    
    Complex s[N];
    int T,n,m,x,num[N];
    LL ans[N],mo[N],fmo[N];
    int main() {
        scanf("%d",&T);
        while(T--) {
            scanf("%d%d",&n,&m);
            int G = cal_root(m);
            for(LL i = 0, t = 1; i < m-1; ++i,t = t*G%m)
                mo[i] = t,fmo[t] = i;
            memset(num,0,sizeof(num));
            LL cnt0 = 0;
            for(int i = 0; i <= m; ++i) ans[i] = 0;
            for(int i = 1; i <= n; ++i) {
                scanf("%d",&x);
                x%=m;
                if(x == 0) cnt0++;
                else {
                    num[fmo[x]]++;
                }
            }
            int n1 = 1;
            for(n1=1;n1<=(2*m-2);n1<<=1);
            for(int i = 0; i < m; ++i) s[i] = Complex(num[i],0);
            for(int i = m; i < n1; ++i) s[i] = Complex(0,0);
            FFT(s,n1,1);
            for(int i = 0; i < n1; ++i) s[i] = s[i]*s[i];
            FFT(s,n1,-1);
            printf("%lld
    ",(LL)cnt0*(n-cnt0)+(LL)cnt0*(cnt0-1)/2);
            for(int i = 0; i <= 2*m-2; ++i) {
                LL now = (LL)(s[i].r+0.5);
                if(i%2==0) now -= num[i/2];
                now/=2;
                ans[mo[i%(m-1)]] += now;
            }
            for(int i = 1; i < m; ++i) printf("%lld
    ",ans[i]);
        }
        return 0;
    }
      
     
    #include<bits/stdc++.h>
    using namespace std;
    #pragma comment(linker, "/STACK:102400000,102400000")
    #define ls i<<1
    #define rs ls | 1
    #define mid ((ll+rr)>>1)
    #define pii pair<int,int>
    #define MP make_pair
    typedef long long LL;
    const long long INF = 1e18+1LL;
    const double pi = acos(-1.0);
    const int N = 533333+10, M = 1e3+20,inf = 2e9;
    
    int MOD;
    inline int mul2(int a, int b){
        return (long long)a * b % MOD;
    }
    int power(int a, int b){
        int ret = 1;
        for (int t = a; b; b >>= 1){
            if (b & 1)ret = mul2(ret, t);
            t = mul2(t, t);
        }
        return ret;
    }
    int cal_root(int mod)
    {
        int factor[26], num = 0, s = mod - 1;
        MOD = mod--;
        for (int i = 2; i * i <= s; i++){
            if (s % i == 0){
                factor[num++] = i;
                while (s % i == 0)s /= i;
            }
        }
        if (s != 1)factor[num++] = s;
        for (int i = 2;; i++){
            int j = 0;
            for (; j < num && power(i, mod / factor[j]) != 1; j++);
            if (j == num)return i;
        }
    }
    
    LL P,G;
    LL mul(LL x,LL y){
        return (x*y-(LL)(x/(long double)P*y+1e-3)*P+P)%P;
    }
    LL qpow(LL x,LL k,LL p){
        LL ret=1;
        while(k){
            if(k&1) ret=mul(ret,x);
            k>>=1;
            x=mul(x,x);
        }
        return ret;
    }
    LL wn[50];
    void getwn(){
        for(int i=1; i<=25; ++i){
            int t=1<<i;
            wn[i]=qpow(G,(P-1)/t,P);
        }
    }
    void NTT_init() {
        P = 3221225473LL,G = 5;
        getwn();
    }
    int len;
    void NTT(LL y[],int op){
        for(int i=1,j=len>>1,k; i<len-1; ++i){
            if(i<j) swap(y[i],y[j]);
            k=len>>1;
            while(j>=k){
                j-=k;
                k>>=1;
            }
            if(j<k) j+=k;
        }
        int id=0;
        for(int h=2; h<=len; h<<=1) {
            ++id;
            for(int i=0; i<len; i+=h){
                LL w=1;
                for(int j=i; j<i+(h>>1); ++j){
                    LL u=y[j],t=mul(y[j+h/2],w);
                    y[j]=u+t;
                    if(y[j]>=P) y[j]-=P;
                    y[j+h/2]=u-t+P;
                    if(y[j+h/2]>=P) y[j+h/2]-=P;
                    w=mul(w,wn[id]);
                }
            }
        }
        if(op==-1){
            for(int i=1; i<len/2; ++i) swap(y[i],y[len-i]);
            LL inv=qpow(len,P-2,P);
            for(int i=0; i<len; ++i) y[i]=mul(y[i],inv);
        }
    }
    LL s[N];
    int T,n,m;
    LL ans[N];
    int num[N],mo[N],fmo[N],root;
    int main() {
        scanf("%d",&T);
        while(T--) {
            scanf("%d%d",&n,&m);
            root = cal_root(m);
            LL cnt0 = 0;
            memset(ans,0,sizeof(ans));
            memset(num,0,sizeof(num));
            for(LL i = 0, t = 1; i < m-1; ++i,t=t*root%m)
                mo[i] = t,fmo[t] = i;
            for(int i = 1; i <= n; ++i) {
                int x;
                scanf("%d",&x);
                x%=m;
                if(x == 0) cnt0++;
                else num[fmo[x]]++;
            }
            for(len = 1; len <= (2*m-2); len<<=1);
            for(int i = 0;i < m; ++i) s[i] = num[i];
            for(int i = m; i < len; ++i) s[i] = 0;
            NTT_init();
            NTT(s,1);
            for(int i = 0; i < len; ++i) s[i] = mul(s[i],s[i]);
            NTT(s,-1);
            for(int i = 0; i <= 2*m-2; ++i) {
                LL now = s[i];
                if(i%2==0) now -= num[i/2];
                now /= 2;
                ans[mo[i%(m-1)]] += now;
            }
            printf("%lld
    ",(LL)cnt0*(n-cnt0)+(LL)cnt0*(cnt0-1)/2);
            for(int i = 1; i < m; ++i) printf("%lld
    ",ans[i]);
        }
        return 0;
    }
  • 相关阅读:
    CentOS6.5 mini安装到VirtualBox虚拟机中
    docker配置redis6.0.5集群
    docker搭建数据库高可用方案PXC
    我通过调试ConcurrentLinkedQueue发现一个IDEA的小虫子(bug), vscode复现, eclipse毫无问题
    ThreadLocal底层原理学习
    第九章
    多线程-java并发编程实战笔记
    Spring-IOC源码解读3-依赖注入
    Spring-IOC源码解读2.3-BeanDefinition的注册
    Spring-IOC源码解读2.2-BeanDefinition的载入和解析过程
  • 原文地址:https://www.cnblogs.com/zxhl/p/7107469.html
Copyright © 2011-2022 走看看