zoukankan      html  css  js  c++  java
  • SHUOJ Arithmetic Sequence (FFT)

    链接:http://acmoj.shu.edu.cn/problem/533/

    题意:求一个序列中,有多少三元组(其中元素不重复)在任意的排列下能构成等差数列。
    分析:等差数列:(A_j-A_i=A_k-A_j),即(2A_j=A_i+A_k),枚举(A_i+A_j)的所有情况对应的个数,再扫一遍求解。
    先统计出每个数对应的出现次数,FFT计算出和的组合情况。但是要减去(A_i+A_i)得到的结果以及(A_i+A_j)以及(A_j+A_i)重复的计算。
    现在对于数(A_j),假设(cnt=2*A_j)的系数,当然cnt中要减去(A_j)本身和一个值与(A_j)相等的数组合而成的情况。枚举完这个数以后,把这个数从序列中抹去,因为这个数对结果做出的贡献已经计算,之后的统计中该数以及该数对结果的贡献不能重复计算。

    #include <bits/stdc++.h>
    using namespace std;
    typedef long long LL;
    const int MAXN = 1e5 + 10;
    const double PI = acos(-1.0);
    struct Complex{
        double x, y;
        inline Complex operator+(const Complex b) const {
            return (Complex){x +b.x,y + b.y};
        }
        inline Complex operator-(const Complex b) const {
            return (Complex){x -b.x,y - b.y};
        }
        inline Complex operator*(const Complex b) const {
            return (Complex){x *b.x -y * b.y,x * b.y + y * b.x};
        }
    } va[MAXN * 2 + MAXN / 2], vb[MAXN * 2 + MAXN / 2];
    int lenth = 1, rev[MAXN * 2 + MAXN / 2];
    int N, M;   // f 和 g 的数量
        //f g和 的系数
        // 卷积结果
        // 大数乘积
    int f[MAXN],g[MAXN];
    vector<LL> conv;
    vector<LL> multi;
    void debug(){for(int i=0;i<conv.size();++i) cout<<conv[i]<<" ";cout<<endl;}
    //f g
    void init()
    {
        int tim = 0;
        lenth = 1;
        conv.clear(), multi.clear();
        memset(va, 0, sizeof va);
        memset(vb, 0, sizeof vb);
        while (lenth <= N + M - 2)
            lenth <<= 1, tim++;
        for (int i = 0; i < lenth; i++)
            rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (tim - 1));
    }
    void FFT(Complex *A, const int fla)
    {
        for (int i = 0; i < lenth; i++){
            if (i < rev[i]){
                swap(A[i], A[rev[i]]);
            }
        }
        for (int i = 1; i < lenth; i <<= 1){
            const Complex w = (Complex){cos(PI / i), fla * sin(PI / i)};
            for (int j = 0; j < lenth; j += (i << 1)){
                Complex K = (Complex){1, 0};
                for (int k = 0; k < i; k++, K = K * w){
                    const Complex x = A[j + k], y = K * A[j + k + i];
                    A[j + k] = x + y;
                    A[j + k + i] = x - y;
                }
            }
        }
    }
    void getConv(){             //求多项式
        init();
        for (int i = 0; i < N; i++)
            va[i].x = f[i];
        for (int i = 0; i < M; i++)
            vb[i].x = g[i];
        FFT(va, 1), FFT(vb, 1);
        for (int i = 0; i < lenth; i++)
            va[i] = va[i] * vb[i];
        FFT(va, -1);
        for (int i = 0; i <= N + M - 2; i++)
            conv.push_back((LL)(va[i].x / lenth + 0.5));
    }
    
    void getMulti()             //求A*B
    {
        getConv();
        multi = conv;
        reverse(multi.begin(), multi.end());
        multi.push_back(0);
        int sz = multi.size();
        for (int i = 0; i < sz - 1; i++){
            multi[i + 1] += multi[i] / 10;
            multi[i] %= 10;
        }
        while (!multi.back() && multi.size() > 1)
            multi.pop_back();
        reverse(multi.begin(), multi.end());
    }
    
    int a[MAXN];
    int cnt[MAXN];
    int main()
    {
        #ifndef ONLINE_JUDGE
            freopen("in.txt","r",stdin);
            freopen("out.txt","w",stdout);
        #endif
        int T; scanf("%d",&T);
        while(T--){
            int n; scanf("%d",&n);
            int mx = -1;
            memset(cnt,0,sizeof(cnt));
            for(int i =1;i<=n;++i){
                scanf("%d",&a[i]);
                mx = max(mx,a[i]);
                cnt[a[i]]++;
            }
            N = M = mx+1;
            for(int i=0;i<N;++i){
                f[i] = g[i] = cnt[i];
            }
            getConv();
            int sz = conv.size();
            for(int i=1;i<=n;++i){
                conv[a[i]*2]--;
            }
            for(int i=0;i<sz;++i){
                conv[i]>>=1;
            }
            LL res=0;
            //debug();
            //sort(a+1,a+n+1);
            for(int i=1;i<=n;++i){
                if(2*a[i]>=sz) continue;
                LL tmp = conv[2*a[i]];
                tmp -= cnt[a[i]]-1;           //减去由自己构成的
                conv[2*a[i]] -= cnt[a[i]]-1;    //将Ai对结果的贡献抹去
                cnt[a[i]]--;                    //将Ai从原序列中抹去
                res += tmp;
            }
            printf("%lld
    ",res);
        }
        return 0;
    }
    
    
  • 相关阅读:
    mysql 统计数据库基本资源sql
    java ffmpeg (Linux)截取视频做封面
    shutil模块
    json模块与pickle模块
    hashlib模块
    sys模块
    os模块
    paramiko模块
    Python reduce() 函数
    瀑布流展示图片
  • 原文地址:https://www.cnblogs.com/xiuwenli/p/9719425.html
Copyright © 2011-2022 走看看