zoukankan      html  css  js  c++  java
  • 听说这个 FFT(实际上是模 998244353 的 NTT,使用 AVX2 加速)跑得非常快

    #pragma GCC target ("avx2")
    
    #include <immintrin.h>
    
    #include<bits/stdc++.h>
    
    using namespace std;
    
    // Fixed-width style shorthands used throughout the file.
    using U8 = unsigned char;
    using I32 = int;
    using U32 = unsigned int;
    using I64 = long long;
    using U64 = unsigned long long;
    
    // Prime modulus of all arithmetic in this file.
    const int P=998244353;
    // Multiplicative inverse of 2 modulo P.
    const int inv2=(P+1)/2;
    // Base used to derive the roots of unity, and its inverse modulo P (3*GInv == P+1).
    const int G=3;
    const int GInv=(P+1)/3;
    // Capacity (as an exponent) of the per-level twiddle tables.
    const int MaxExp=22;
    // Capacity of the global coefficient arrays (2^21 entries).
    const int MAXN=1<<21;
    // Number of series terms precomputed by init().
    const int MAXK=500005;
    // Multiplier and shift used by the vectorised reduction in VMul (BmM is about 2^58 / P).
    const int BmM=288737297;
    const int BmW=29;
    
    // Buffered reader over stdin and buffered writer over stdout.
    // Reading: xchar() hands out one byte at a time from a static buffer.
    // Writing: bytes are collected in wbuf and flushed when the buffer is full
    // and when the object is destroyed.
    struct FastIO{
        static const int S=1e7;   // capacity in bytes of the read buffer and of wbuf
        int wpos;                 // number of bytes currently pending in wbuf
        char wbuf[S];             // pending output
        FastIO():wpos(0){}
        // Returns the next input byte, refilling the buffer from stdin when it
        // has been consumed. Calls exit(0) when stdin yields no more data.
        inline int xchar()
        {
            static char buf[S];
            static int len=0,pos=0;
            if(pos==len)
                pos=0,len=fread(buf,1,S,stdin);
            if(pos==len) exit(0);
            return buf[pos++];
        }
        // Reads a decimal number as int: skips every byte <= 32, then
        // accumulates consecutive digits. A sign character is not recognised.
        inline int operator () ()
        {
            int c=xchar(),x=0;
            while (c<=32) c=xchar();
            for (;'0'<=c&&c<='9';c=xchar()) x=x*10+c-'0';
            return x;
        }
        // Same parser as operator()(), accumulating into a 64-bit result.
        inline I64 operator ! ()
        {
            int c=xchar();I64 x=0;
            while (c<=32) c=xchar();
            for (; '0'<=c&&c<='9';c=xchar()) x=x*10+c-'0';
            return x;
        }
        // Appends one byte to the output buffer, flushing it first if it is full.
        inline void wchar(int x)
        {
            if (wpos==S) fwrite(wbuf,1,S,stdout),wpos=0;
            wbuf[wpos++]=x;
        }
        // Writes x in decimal followed by the byte `tail`.
        // The magnitude is taken in unsigned arithmetic so that the most
        // negative I64 value does not overflow on negation.
        inline void wnum(I64 x,int tail)
        {
            unsigned long long mag=(unsigned long long)x;
            if (x<0) wchar('-'),mag=0-mag;
            char s[24];
            int n=0;
            do s[n++]='0'+mag%10,mag/=10; while(mag);  // digits, least significant first
            while(n--) wchar(s[n]);
            wchar(tail);
        }
        // Writes x followed by a newline.
        inline void operator () (I64 x)
        {
            wnum(x,'\n');
        }
        // Writes x followed by a single space.
        inline void space(I64 x)
        {
            wnum(x,' ');
        }
        // Writes a newline.
        inline void nextline()
        {
            wchar('\n');
        }
        // Flushes whatever is still pending in wbuf.
        ~FastIO()
        {
            if (wpos) fwrite(wbuf, 1, wpos, stdout), wpos = 0;
        }
    }io;
    
    // Extended Euclid: returns gcd(A, B) and sets x, y to Bezout coefficients
    // with A*x + B*y == gcd(A, B).
    U32 GcdEx(U32 A, U32 B, I32& x, I32& y) {
        if (B == 0) {
            // gcd(A, 0) = A = 1*A + 0*B
            x = 1;
            y = 0;
            return A;
        }
        const U32 quot = A / B;
        const U32 rem = A % B;
        // The recursive call solves for (B, rem); note the swapped outputs.
        const U32 g = GcdEx(B, rem, y, x);
        y -= x * (I32) quot;
        return g;
    }
    
    // Maps a signed value to its representative in [0, P).
    inline U32 MNorm(I32 V) {
        I32 r = V % P;
        if (r < 0)
            r += P;
        return (U32) r;
    }
    
    // Returns A + B with one conditional subtraction of P.
    inline U32 MAdd(U32 A, U32 B) {
        U32 sum = A + B;
        if (sum >= P)
            sum -= P;
        return sum;
    }
    
    // Returns A - B, adding P back when the subtraction would go below zero.
    inline U32 MSub(U32 A, U32 B) {
        if (A >= B)
            return A - B;
        return A - B + P;
    }
    
    // Returns A * B modulo P, using a 64-bit intermediate product.
    inline U32 MMul(U32 A, U32 B) {
        const U64 wide = (U64) A * B;
        return (U32) (wide % P);
    }
    
    // Returns A raised to the power B modulo P (binary exponentiation).
    inline U32 MPow(U32 A, U32 B) {
        U32 acc = 1;
        for (U32 base = A, e = B; e != 0; e >>= 1) {
            if (e & 1u)
                acc = MMul(acc, base);
            base = MMul(base, base);
        }
        return acc;
    }
    
    // Returns the Bezout coefficient of N from GcdEx(N, P), reduced into [0, P);
    // this is the inverse of N modulo P when gcd(N, P) == 1.
    inline U32 MInv(U32 N) {
        I32 coefN, coefP;
        GcdEx(N, P, coefN, coefP);
        I32 r = coefN % P;
        if (r < 0)
            r += P;
        return (U32) r;
    }
    
    // Loads eight consecutive U32 values with the aligned 256-bit load.
    inline __m256i VLod(const U32* __restrict__ A) {
        const __m256i* src = (const __m256i*) A;
        return _mm256_load_si256(src);
    }
    
    // Stores the eight 32-bit lanes of v to A with the aligned 256-bit store.
    inline void VSto(U32* __restrict__ A, __m256i v) {
        __m256i* dst = (__m256i*) A;
        _mm256_store_si256(dst, v);
    }
    
    // Spreads the even-indexed 32-bit lanes of v into the low halves of the
    // four 64-bit slots. The 0x11 selector bytes put filler into the high
    // halves, which _mm256_mul_epi32 does not read.
    inline __m256i VEx0(__m256i v) {
        // Same byte selectors as before, listed from the lowest 64-bit slot upwards.
        const __m256i evenSel = _mm256_setr_epi64x(
            0x1111111103020100, 0x111111110b0a0908,
            0x1111111113121110, 0x111111111b1a1918
        );
        return _mm256_shuffle_epi8(v, evenSel);
    }
    
    // Spreads the odd-indexed 32-bit lanes of v into the low halves of the
    // four 64-bit slots. The 0x11 selector bytes put filler into the high
    // halves, which _mm256_mul_epi32 does not read.
    inline __m256i VEx1(__m256i v) {
        // Same byte selectors as before, listed from the lowest 64-bit slot upwards.
        const __m256i oddSel = _mm256_setr_epi64x(
            0x1111111107060504, 0x111111110f0e0d0c,
            0x1111111117161514, 0x111111111f1e1d1c
        );
        return _mm256_shuffle_epi8(v, oddSel);
    }
    
    // Builds a vector whose even 32-bit lanes are the even lanes of v0 and whose
    // odd 32-bit lanes are the even lanes of v1.
    inline __m256i VIntlv(__m256i v0, __m256i v1) {
        // 0xb1 swaps the two 32-bit lanes inside every 64-bit slot of v1.
        const __m256i swapped = _mm256_shuffle_epi32(v1, 0xb1);
        // 0xaa takes the odd lanes from the second operand.
        return _mm256_blend_epi32(v0, swapped, 0xaa);
    }
    
    // Lane-wise va + vb with one conditional subtraction of P, chosen by taking
    // the unsigned minimum of the sum and the sum minus P.
    inline __m256i VAdd(__m256i va, __m256i vb) {
        const __m256i modulus = _mm256_set1_epi32(P);
        const __m256i sum = _mm256_add_epi32(va, vb);
        const __m256i reduced = _mm256_sub_epi32(sum, modulus);
        return _mm256_min_epu32(sum, reduced);
    }
    
    // Lane-wise va - vb with one conditional addition of P, chosen by taking
    // the unsigned minimum of the difference and the difference plus P.
    inline __m256i VSub(__m256i va, __m256i vb) {
        const __m256i modulus = _mm256_set1_epi32(P);
        const __m256i diff = _mm256_sub_epi32(va, vb);
        const __m256i lifted = _mm256_add_epi32(diff, modulus);
        return _mm256_min_epu32(diff, lifted);
    }
    
    // Multiplies eight pairs of values modulo P.
    // va0/vb0 hold the factors for the even result lanes and va1/vb1 those for
    // the odd result lanes, each in the low 32 bits of a 64-bit slot (the layout
    // produced by VEx0 / VEx1). The result has one 32-bit value per lane.
    // NOTE(review): the final selection assumes the inputs are residues below P - confirm.
    inline __m256i VMul(__m256i va0, __m256i va1, __m256i vb0, __m256i vb1) {
        const __m256i vm32 = _mm256_set1_epi32(P);
        const __m256i vm64 = _mm256_set1_epi64x(P);
        const __m256i vbmm = _mm256_set1_epi64x(BmM);
        // Full 64-bit products of the low 32 bits of every slot.
        __m256i vmul0 = _mm256_mul_epi32(va0, vb0);
        __m256i vmul1 = _mm256_mul_epi32(va1, vb1);
        // Low 32 bits of all eight products, back in natural lane order.
        __m256i vlow = VIntlv(vmul0, vmul1);
        // Quotient estimate: ((product >> 29) * BmM) >> BmW.
        __m256i vquo0 = _mm256_srli_epi64(_mm256_mul_epi32(_mm256_srli_epi64(vmul0, 29), vbmm), BmW);
        __m256i vquo1 = _mm256_srli_epi64(_mm256_mul_epi32(_mm256_srli_epi64(vmul1, 29), vbmm), BmW);
        // quotient * P; only the low 32 bits are kept by the interleave below.
        __m256i vval0 = _mm256_mul_epi32(vquo0, vm64);
        __m256i vval1 = _mm256_mul_epi32(vquo1, vm64);
        __m256i vval = VIntlv(vval0, vval1);
        // Candidate remainder modulo 2^32, plus the same value shifted by +P and -P.
        __m256i vra = _mm256_sub_epi32(vlow, vval);
        __m256i vrb = _mm256_add_epi32(vra, vm32);
        __m256i vrc = _mm256_sub_epi32(vra, vm32);
        // The unsigned minimum of the three candidates is returned.
        __m256i vmin = _mm256_min_epu32(vra, vrb);
        return _mm256_min_epu32(vmin, vrc);
    }
    
    inline __m256i VMul(__m256i va, __m256i vb0, __m256i vb1) {
        return VMul(VEx0(va), VEx1(va), vb0, vb1);
    }
    
    inline __m256i VMul(__m256i va, __m256i vb) {
        return VMul(va, VEx0(vb), VEx1(vb));
    }
    
    // Multiplies A[0..Len) in place by the scalar W modulo P.
    // Whole groups of eight elements go through the vector path (aligned
    // load/store via VLod / VSto); any remaining elements are handled one by
    // one, so no element at or beyond Len is read or written even when Len is
    // not a multiple of 8.
    inline void VMul(U32* __restrict__ A, U32 Len, U32 W) {
        const U32 VecEnd = Len & ~7u;  // largest multiple of 8 not exceeding Len
        const __m256i vw = _mm256_set1_epi64x(W);
        U32 i = 0;
        for (; i < VecEnd; i += 8)
            VSto(A + i, VMul(VLod(A + i), vw, vw));
        for (; i < Len; ++i)
            A[i] = MMul(A[i], W);
    }
    
    // Element-wise A[i] = A[i] * B[i] modulo P for i in [0, Len).
    // Whole groups of eight elements go through the vector path (aligned
    // load/store via VLod / VSto); any remaining elements are handled one by
    // one, so no element at or beyond Len is read or written even when Len is
    // not a multiple of 8.
    inline void VMul(U32* __restrict__ A, const U32* __restrict__ B, U32 Len) {
        const U32 VecEnd = Len & ~7u;  // largest multiple of 8 not exceeding Len
        U32 i = 0;
        for (; i < VecEnd; i += 8)
            VSto(A + i, VMul(VLod(A + i), VLod(B + i)));
        for (; i < Len; ++i)
            A[i] = MMul(A[i], B[i]);
    }
    
    // Squares A[0..Len) in place modulo P.
    // Whole groups of eight elements go through the vector path (aligned
    // load/store via VLod / VSto); any remaining elements are handled one by
    // one, so no element at or beyond Len is read or written even when Len is
    // not a multiple of 8.
    inline void VSqr(U32* __restrict__ A, U32 Len) {
        const U32 VecEnd = Len & ~7u;  // largest multiple of 8 not exceeding Len
        U32 i = 0;
        for (; i < VecEnd; i += 8) {
            __m256i va = VLod(A + i);
            // Split once and reuse the halves as both factors.
            __m256i v0 = VEx0(va);
            __m256i v1 = VEx1(va);
            VSto(A + i, VMul(v0, v1, v0, v1));
        }
        for (; i < Len; ++i)
            A[i] = MMul(A[i], A[i]);
    }
    
    // Per-level tables indexed by the exponent Exp, filled by NttInitAll:
    // WbFwd[Exp]  = G^((P-1) >> Exp), read by NttFwd;
    // WbInv[Exp]  = GInv^((P-1) >> Exp), read by NttInv;
    // LenInv[Exp] = inverse of 2^Exp modulo P, the final scaling in NttInv.
    U32 WbFwd[MaxExp + 1];
    U32 WbInv[MaxExp + 1];
    U32 LenInv[MaxExp + 1];
    
    inline void NttInitAll(int Max) {
        for (int Exp = 0; Exp <= Max; ++Exp) {
            WbFwd[Exp] = MPow(G, (P - 1) >> Exp);
            WbInv[Exp] = MPow(GInv, (P - 1) >> Exp);
            LenInv[Exp] = MInv(1u << Exp);
        }
    }
    
    // Span-1 butterflies: each adjacent pair (x, y) becomes (x + y, x - y) modulo P.
    // Shared by the forward and the inverse transform.
    inline void NttImpl1(U32* __restrict__ A, U32 Len) {
        for (U32 j = 0; j < Len; j += 2) {
            const U32 lo = A[j];
            const U32 hi = A[j + 1];
            A[j] = MAdd(lo, hi);
            A[j + 1] = MSub(lo, hi);
        }
    }
    
    // Forward span-2 butterflies on blocks of four: sums go to the first half,
    // differences to the second half, and the last difference is multiplied by Wn.
    inline void NttFwd2(U32* __restrict__ A, U32 Len, U32 Wn) {
        for (U32 base = 0; base < Len; base += 4) {
            U32* q = A + base;
            const U32 x0 = q[0], x1 = q[1];
            const U32 y0 = q[2], y1 = q[3];
            q[0] = MAdd(x0, y0);
            q[1] = MAdd(x1, y1);
            q[2] = MSub(x0, y0);
            q[3] = MMul(MSub(x1, y1), Wn);
        }
    }
    
    // Forward span-4 butterflies on blocks of eight, using 128-bit SSE ops.
    // For each block the first four elements receive the sums, the last four the
    // differences; differences 1..3 are then multiplied by Wn, Wn^2 and Wn^3.
    // A + j is accessed with aligned 128-bit loads/stores, so A must be
    // suitably aligned.
    inline void NttFwd3(U32* __restrict__ A, U32 Len, U32 Wn) {
        U32 W2 = MMul(Wn, Wn);
        U32 W3 = MMul(W2, Wn);
        const __m128i vm32 = _mm_set1_epi32(P);
        for (U32 j = 0; j < Len; j += 8) {
            __m128i va = _mm_load_si128((const __m128i*) (A + j));
            __m128i vb = _mm_load_si128((const __m128i*) (A + j + 4));
            __m128i vc = _mm_add_epi32(va, vb);
            __m128i vd = _mm_sub_epi32(va, vb);
            // Subtract P from the sum in lanes where P > sum does not hold.
            __m128i ve = _mm_sub_epi32(vc, _mm_andnot_si128(_mm_cmpgt_epi32(vm32, vc), vm32));
            // Add P to the difference in lanes where vb > va.
            __m128i vf = _mm_add_epi32(vd, _mm_and_si128(_mm_cmpgt_epi32(vb, va), vm32));
            _mm_store_si128((__m128i*) (A + j), ve);
            _mm_store_si128((__m128i*) (A + j + 4), vf);
            // Twiddle the stored differences (element j + 4 keeps factor 1).
            A[j + 5] = MMul(Wn, A[j + 5]);
            A[j + 6] = MMul(W2, A[j + 6]);
            A[j + 7] = MMul(W3, A[j + 7]);
        }
    }
    
    // In-place forward transform of A[0 .. 2^Exp) modulo P.
    // Stages run from the widest butterfly span (2^(Exp-1)) down to span 1:
    // spans >= 8 use the AVX2 loop below, spans 4, 2 and 1 are delegated to
    // NttFwd3, NttFwd2 and NttImpl1. No reordering pass is performed here.
    // NOTE(review): the output is presumably in the permuted order that NttInv
    // expects as input - confirm before using the spectrum directly.
    // Reads WbFwd[Exp], so NttInitAll must have covered Exp; for Exp >= 3 the
    // data is accessed with aligned vector loads/stores.
    inline void NttFwd(U32* __restrict__ A, int Exp) {
        U32 Len = 1u << Exp;
        U32 Wn = WbFwd[Exp];
        for (int i = Exp - 1; i >= 3; --i) {
            U32 ChkSiz = 1u << i;
            // Powers Wn^1 .. Wn^8 for this stage.
            U32 tw2 = MMul(Wn, Wn);
            U32 tw3 = MMul(tw2, Wn);
            U32 tw4 = MMul(tw3, Wn);
            U32 tw5 = MMul(tw4, Wn);
            U32 tw6 = MMul(tw5, Wn);
            U32 tw7 = MMul(tw6, Wn);
            U32 twn = MMul(tw7, Wn);
            // Twiddles for the first eight positions (lane k holds Wn^k) and the
            // factor Wn^8 that advances them by eight positions.
            __m256i vw32 = _mm256_set_epi32(tw7, tw6, tw5, tw4, tw3, tw2, Wn, 1);
            __m256i vwn = _mm256_set1_epi64x(twn);
            for (U32 j = 0; j < Len; j += 2u << i) {
                U32* A_ = A + j;
                U32* B_ = A_ + ChkSiz;
                __m256i vw = vw32;
                for (U32 k = 0; k < ChkSiz; k += 8) {
                    __m256i va = VLod(A_ + k);
                    __m256i vb = VLod(B_ + k);
                    __m256i vw0 = VEx0(vw);
                    __m256i vw1 = VEx1(vw);
                    __m256i vc = VAdd(va, vb);
                    __m256i vd = VSub(va, vb);
                    // Sum goes to the first half, twiddled difference to the second.
                    VSto(A_ + k, vc);
                    VSto(B_ + k, VMul(vd, vw0, vw1));
                    // Advance every twiddle lane by Wn^8.
                    vw = VMul(vw0, vw1, vwn, vwn);
                }
            }
            // The next (narrower) stage uses the squared root.
            Wn = MMul(Wn, Wn);
        }
        if (Exp >= 3) {
            NttFwd3(A, Len, Wn);
            Wn = MMul(Wn, Wn);
        }
        if (Exp >= 2)
            NttFwd2(A, Len, Wn);
        if (Exp)
            NttImpl1(A, Len);
    }
    
    // Inverse span-2 butterflies on blocks of four: the last element is first
    // multiplied by Wn, then sums go to the first half and differences to the second.
    inline void NttInv2(U32* __restrict__ A, U32 Len, U32 Wn) {
        for (U32 base = 0; base < Len; base += 4) {
            U32* q = A + base;
            const U32 top0 = q[0];
            const U32 top1 = q[1];
            const U32 bot0 = q[2];
            const U32 bot1 = MMul(q[3], Wn);
            q[0] = MAdd(top0, bot0);
            q[1] = MAdd(top1, bot1);
            q[2] = MSub(top0, bot0);
            q[3] = MSub(top1, bot1);
        }
    }
    
    // Inverse span-4 butterflies on blocks of eight, using 128-bit SSE ops.
    // Elements 5..7 of each block are first multiplied by Wn, Wn^2 and Wn^3;
    // then the first four elements receive the sums and the last four the
    // differences. A + j is accessed with aligned 128-bit loads/stores, so A
    // must be suitably aligned.
    inline void NttInv3(U32* __restrict__ A, U32 Len, U32 Wn) {
        U32 W2 = MMul(Wn, Wn);
        U32 W3 = MMul(W2, Wn);
        const __m128i vm32 = _mm_set1_epi32(P);
        for (U32 j = 0; j < Len; j += 8) {
            // Twiddle the second half before it is loaded (element j + 4 keeps factor 1).
            A[j + 5] = MMul(Wn, A[j + 5]);
            A[j + 6] = MMul(W2, A[j + 6]);
            A[j + 7] = MMul(W3, A[j + 7]);
            __m128i va = _mm_load_si128((const __m128i*) (A + j));
            __m128i vb = _mm_load_si128((const __m128i*) (A + j + 4));
            __m128i vc = _mm_add_epi32(va, vb);
            __m128i vd = _mm_sub_epi32(va, vb);
            // Subtract P from the sum in lanes where P > sum does not hold.
            __m128i ve = _mm_sub_epi32(vc, _mm_andnot_si128(_mm_cmpgt_epi32(vm32, vc), vm32));
            // Add P to the difference in lanes where vb > va.
            __m128i vf = _mm_add_epi32(vd, _mm_and_si128(_mm_cmpgt_epi32(vb, va),vm32));
            _mm_store_si128((__m128i*) (A + j), ve);
            _mm_store_si128((__m128i*) (A + j + 4), vf);
        }
    }
    
    // In-place inverse transform of A[0 .. 2^Exp) modulo P, including the final
    // multiplication by LenInv[Exp].
    // Stages run from span 1 up to span 2^(Exp-1): NttImpl1, NttInv2, NttInv3,
    // then the AVX2 loop for spans >= 8. No reordering pass is performed here.
    // NOTE(review): the input is presumably expected in the order produced by
    // NttFwd - confirm.
    // Reads WbInv[Exp] and LenInv[Exp], so NttInitAll must have covered Exp.
    // Exp == 0 returns immediately without touching A.
    inline void NttInv(U32* __restrict__ A, int Exp) {
        if (!Exp)
            return;
        U32 Len = 1u << Exp;
        NttImpl1(A, Len);
        if (Exp == 1) {
            VMul(A, Len, LenInv[1]);
            return;
        }
        // Ws[i] = WbInv[Exp]^(2^i): the root for the stage whose span is 2^(Exp-1-i).
        U32 Ws[MaxExp];
        Ws[0] = WbInv[Exp];
        for (int i = 1; i < Exp; ++i)
            Ws[i] = MMul(Ws[i - 1], Ws[i - 1]);
        NttInv2(A, Len, Ws[Exp - 2]);
        if (Exp == 2) {
            VMul(A, Len, LenInv[2]);
            return;
        }
        NttInv3(A, Len, Ws[Exp - 3]);
        if (Exp == 3) {
            VMul(A, Len, LenInv[3]);
            return;
        }
        for (int i = 3; i < Exp; ++i) {
            U32 ChkSiz = 1u << i;
            U32 Wn = Ws[Exp - 1 - i];
            // Powers Wn^1 .. Wn^8 for this stage.
            U32 tw2 = MMul(Wn, Wn);
            U32 tw3 = MMul(tw2, Wn);
            U32 tw4 = MMul(tw3, Wn);
            U32 tw5 = MMul(tw4, Wn);
            U32 tw6 = MMul(tw5, Wn);
            U32 tw7 = MMul(tw6, Wn);
            U32 twn = MMul(tw7, Wn);
            // Twiddles for the first eight positions (lane k holds Wn^k) and the
            // factor Wn^8 that advances them by eight positions.
            __m256i vw32 = _mm256_set_epi32(tw7, tw6, tw5, tw4, tw3, tw2, Wn, 1);
            __m256i vwn = _mm256_set1_epi64x(twn);
            for (U32 j = 0; j < Len; j += 2u << i) {
                U32* A_ = A + j;
                U32* B_ = A_ + ChkSiz;
                __m256i vw = vw32;
                for (U32 k = 0; k < ChkSiz; k += 8) {
                    __m256i vw0 = VEx0(vw);
                    __m256i vw1 = VEx1(vw);
                    // Twiddle the second half first, then form sum and difference.
                    __m256i vb = VMul(VLod(B_ + k), vw0, vw1);
                    vw = VMul(vw0, vw1, vwn, vwn);
                    __m256i va = VLod(A_ + k);
                    __m256i vc = VAdd(va, vb);
                    __m256i vd = VSub(va, vb);
                    VSto(A_ + k, vc);
                    VSto(B_ + k, vd);
                }
            }
        }
        // Scale by the inverse of the transform length.
        VMul(A, Len, LenInv[Exp]);
    }
    
    // Returns the position of the highest set bit of 2*N - 1 (32-bit arithmetic),
    // found by smearing that bit downwards and hashing with a multiply-and-shift
    // table lookup. For N >= 1 without overflow in 2*N this is ceil(log2(N)).
    inline int Log2Ceil(U32 N) {
        static const U8 BitPos[32] = {
            0,  9,  1,  10, 13, 21, 2,  29,
            11, 14, 16, 18, 22, 25, 3,  30,
            8,  12, 20, 28, 15, 17, 24, 7,
            19, 27, 23, 6,  26, 5,  4,  31,
        };
        U32 v = (N << 1) - 1;
        // Set every bit below the highest set bit.
        for (int s = 1; s <= 16; s <<= 1)
            v |= v >> s;
        const U32 slot = (v * 0x07c4acddu) >> 27;
        return (int) BitPos[slot];
    }
    
    // f: input series set up by init(); h: answer table indexed by the query value.
    // h is handed to the NTT routines, which use aligned 256-bit loads/stores,
    // so both arrays are explicitly aligned to 32 bytes.
    alignas(32) U32 f[MAXN];
    alignas(32) U32 h[MAXN];
    
    // Zeroes the first 2^k elements of a, where 2^k is the smallest power of two
    // that is >= n (a single element when n <= 1).
    void clear(U32 *a,int n)
    {
        int cap=1;
        for(;cap<n;cap<<=1);
        std::fill_n(a,cap,U32(0));
    }
    
    // Scratch buffers, one namespace per routine so that nested calls
    // (Sqrt -> Inv -> Mul) do not overwrite each other's temporaries.
    // The buffers are passed to the NTT routines, which use aligned 256-bit
    // loads/stores, so every array is explicitly aligned to 32 bytes.
    namespace poly_multiply{
        alignas(32) U32 t1[MAXN];
        alignas(32) U32 t2[MAXN];
        alignas(32) U32 t3[MAXN];
        alignas(32) U32 t4[MAXN];
    }
    namespace poly_inverse{
        alignas(32) U32 t1[MAXN];
        alignas(32) U32 t2[MAXN];
    }
    namespace poly_sqrt{
        alignas(32) U32 t1[MAXN];
        alignas(32) U32 t2[MAXN];
    }
    
    // Res = a * b, where a and b are polynomials given by their first n coefficients.
    // The product is computed with a transform of length tot, the smallest power
    // of two >= 2*n, and all tot entries of Res are written.
    // Res may alias a or b: the inputs are copied into scratch buffers before Res
    // is written. (Res is fully overwritten, so it is not cleared beforehand.)
    // When a and b are the same array, the single transform of a is reused.
    void Mul(U32 *a,U32 *b,U32 *Res,int n)
    {
        using namespace poly_multiply;
        int tot=1,lg2=0;
        while(tot<2*n) tot<<=1,lg2++;
        clear(t1,tot);
        for(int i=0;i<n;i++) t1[i]=a[i];
        NttFwd(t1,lg2);
        const U32 *rhs=t1;
        if(a!=b)
        {
            clear(t2,tot);
            for(int i=0;i<n;i++) t2[i]=b[i];
            NttFwd(t2,lg2);
            rhs=t2;
        }
        // Pointwise product in the transformed domain, then transform back.
        for(int i=0;i<tot;i++) Res[i]=MMul(t1[i],rhs[i]);
        NttInv(Res,lg2);
    }
    
    // Res = a * b * c, each input given by its first n coefficients.
    // The transform length tot is the smallest power of two >= 2*n, so the result
    // is only free of wrap-around when the true product has fewer than tot terms;
    // all tot entries of Res are written.
    // Res may alias any input: the inputs are copied into scratch buffers before
    // Res is written. (Res is fully overwritten, so it is not cleared beforehand.)
    // When a and b are the same array, the single transform of a is reused.
    void Mul(U32 *a,U32 *b,U32 *c,U32 *Res,int n)
    {
        using namespace poly_multiply;
        int tot=1,lg2=0;
        while(tot<2*n) tot<<=1,lg2++;
        clear(t1,tot);clear(t3,tot);
        for(int i=0;i<n;i++) t1[i]=a[i],t3[i]=c[i];
        NttFwd(t1,lg2);NttFwd(t3,lg2);
        const U32 *rhs=t1;
        if(a!=b)
        {
            clear(t2,tot);
            for(int i=0;i<n;i++) t2[i]=b[i];
            NttFwd(t2,lg2);
            rhs=t2;
        }
        // Pointwise triple product in the transformed domain, then transform back.
        for(int i=0;i<tot;i++) Res[i]=MMul(MMul(t1[i],rhs[i]),t3[i]);
        NttInv(Res,lg2);
    }
    
    // Power-series inverse: fills Res with the first coefficients of 1/a by
    // doubling the known prefix with Res <- 2*Res - Res*Res*a until it covers n.
    // a[0] is inverted with MPow(a[0], P-2).
    void Inv(U32 *a,U32 *Res,int n)
    {
        U32 *const dbl=poly_inverse::t1;
        U32 *const prod=poly_inverse::t2;
        clear(Res,n);
        clear(dbl,n);
        clear(prod,n);
        Res[0]=MPow(a[0],P-2);
        for(int have=1;have<n;)
        {
            // 2*Res for the currently known prefix.
            for(int i=0;i<have;i++) dbl[i]=MAdd(Res[i],Res[i]);
            have*=2;
            // Res*Res*a, of which the first `have` coefficients are used.
            Mul(Res,Res,a,prod,have);
            for(int i=0;i<have;i++) Res[i]=MSub(dbl[i],prod[i]);
        }
    }
    
    // Power-series square root: refines Res with Res <- (Res*Res + a) / (2*Res)
    // while doubling the working length up to the first power of two above n.
    // Starts from Res[0] = 1, i.e. it assumes a[0] == 1.
    void Sqrt(U32 *a,U32 *Res,int n)
    {
        U32 *const sq=poly_sqrt::t1;
        U32 *const rcp=poly_sqrt::t2;
        clear(Res,n);
        clear(sq,n);
        clear(rcp,n);
        Res[0]=1;
        int limit=1;
        while(limit<=n) limit*=2;
        for(int len=1;len<limit;)
        {
            len*=2;
            Mul(Res,Res,sq,len);   // sq = Res*Res
            Inv(Res,rcp,len);      // rcp = 1/Res
            // sq = (Res*Res + a) / 2 on the first len coefficients.
            for(int i=0;i<len;i++)
            {
                const U32 sum=MAdd(sq[i],a[i]);
                sq[i]=MMul(sum,inv2);
            }
            Mul(sq,rcp,Res,len);
        }
    }
    
    void init()
    {
        NttInitAll(21);
        f[2]=1;f[1]=MSub(0,6);f[0]=1;
        int t=clock();
        Sqrt(f,h,MAXK);
        for(int i=0;i<=MAXK;i++) h[i]=MSub(0,h[i]);
        h[0]++;h[1]--;
        for(int i=0;i<=MAXK;i++) h[i]=MMul(h[i],inv2);
    }
    
    // Reads the number of queries T from grid.in, builds the table once, then for
    // each query reads n and writes h[n] on its own line to grid.out.
    int main()
    {
        freopen("grid.in","r",stdin);
        freopen("grid.out","w",stdout);
        int T=io();
        init();
        while(T--)
        {
            int n=io();
            io(h[n]);
        }
    }
    
  • 相关阅读:
    spring集成环境下的axis webservice的发布,调试
    Axis2 webservice 之使用java调用webservice
    Axis2 webservice入门--写个简单的webservice
    Axis2 webservice入门--开发环境搭建,概念理解
    使用js给页面显示的图片添加水印效果
    java使用dom4j解析xml文件
    一个Java递归程序
    Java-Scanner键盘输入
    JVM—调优参数学习
    JVM—类加载过程
  • 原文地址:https://www.cnblogs.com/Cwolf9/p/10317409.html
Copyright © 2011-2022 走看看