zoukankan      html  css  js  c++  java
  • [黑科技]常数优化的一些技巧

    感谢wys和小火车普及这些技巧qwq 这篇文章大概没什么营养

    我们来看一道十分简单的题目:

    设n=131072,输入两个长度为n的数列$a_0,a_1...a_{n-1}$和$b_0,b_1...b_{n-1}$,要求输出一个长度为n的数列$c_0,c_1...c_{n-1}$ 。

    其中$c_i=max(a_j+b_{i-j})~(0 leq j leq i)$ ,$1 leq a_i, b_i leq 10^9$ 。

    首先我们来讲讲这题怎么做。

    如果数据是随机的,那么有一种神奇的做法:在a和b中分别挑出最大的p个元素,对于每个i暴力枚举每个p进行更新,这样的复杂度是O(np)的,正确性我不会分析= =

    那么数据不是随机的...那么估计没有什么快速的算法,不如暴力!

    以下的运行时间均为在我的渣渣笔记本中测试得到,仅供参考。测试环境Ubuntu,编译选项只有-O2。

    0.

    #define SZ 666666
    int a[SZ],b[SZ],c[SZ];
    const int n=131072;
    int main()
    {
        for(int i=0;i<n;i++) scanf("%d",a+i);
        for(int i=0;i<n;i++) scanf("%d",b+i);
        for(int i=0;i<=n;i++)
        {
            for(int j=0;j<=i;j++) c[i]=max(c[i],a[j]+b[i-j]);
        }
        for(int i=0;i<n;i++) printf("%d ",c[i]);puts("");
        cerr<<clock()<<"ms
    ";
    }

    simple and stupid。我们来测试一下...跑了8s。虽然不是太糟,但是还是很慢...我们来进行一波有理有据的常数优化吧。

    1. 加上输入输出优化

    const int n=131072;
    char ch,B[1<<20],*S=B,*T=B;
    #define getc() (S==T&&(T=(S=B)+fread(B,1,1<<20,stdin),S==T)?0:*S++)
    #define isd(c) (c>='0'&&c<='9')
    int aa,bb;int F(){
        while(ch=getc(),!isd(ch)&&ch!='-');ch=='-'?aa=bb=0:(aa=ch-'0',bb=1);
        while(ch=getc(),isd(ch))aa=aa*10+ch-'0';return bb?aa:-aa;
    }
    #define gi F()
    #define BUFSIZE 5000000
    namespace fob {char b[BUFSIZE]={},*f=b,*g=b+BUFSIZE-2;}
    #define pob (fwrite(fob::b,sizeof(char),fob::f-fob::b,stdout),fob::f=fob::b,0)
    #define pc(x) (*(fob::f++)=(x),(fob::f==fob::g)?pob:0)
    struct foce {~foce() {pob; fflush(stdout);}} _foce;
    namespace ib {char b[100];}
    inline void pint(int x)
    {
        if(x==0) {pc(48); return;}
        //if(x<0) {pc('-'); x=-x;} //如果有负数就加上 
        char *s=ib::b;
        while(x) *(++s)=x%10, x/=10;
        while(s!=ib::b) pc((*(s--))+48);
    }
    int main()
    {
        for(int i=0;i<n;i++) a[i]=gi;
        for(int i=0;i<n;i++) b[i]=gi;
        for(int i=0;i<=n;i++)
            for(int j=0;j<=i;j++) c[i]=max(c[i],a[j]+b[i-j]);
        for(int i=0;i<n;i++) pint(c[i]),pc(' ');pc(10);
        cerr<<clock()<<"ms
    ";
    }

    虽然看起来只有10w多个数,我们还是加一波输入输出优化试试...

    居然跑了10s。比原来还慢...这和预期不太相符啊...在windows上加了输入输出确实会变快,但是ubuntu下变慢了...大概输入输出少的时候最好还是不要加优化?

    以下的测试全部基于输入输出优化,就假装加了优化跑的更快好了。

    2. 手写stl

    虽然这段代码非常短,但是我们还是使用了一个stl:max。我们来常数优化一波!

    for(int i=0;i<=n;i++)
            for(int j=0;j<=i;j++)
                if(a[j]+b[i-j]>c[i])
                    c[i]=a[j]+b[i-j];

    测了测,跑了5.3s,比原来快了快一半!可喜可贺。

    3. 把if改成三目?

    这时候我想起了wys的教导:少用if,多用三目。

    for(int i=0;i<=n;i++)
            for(int j=0;j<=i;j++)
                (a[j]+b[i-j]>c[i])
                ?(c[i]=a[j]+b[i-j]):0;

    这样写跑了6.1s,居然比if还慢?

    有理有据的分析:正常情况下,if改成三目会变快的原因是因为消除了分支预测,分支预测错误跳转的代价很大,而上面那段代码预测错误几率很小,所以if就比较快了。

    4. 循环展开

    为了写起来方便,首先我们将b数组反序,这样可以减少运算量,接下来把内层j循环展开。

    int main()
    {
        for(register int i=0;i<n;i++) a[i]=gi;
        for(register int i=0;i<n;i++) b[n-i]=gi;
        for(register int i=0;i<n;i++)
        {
            int*r=b+n-i;
            for(register int j=0;j<=i;j+=8)
            {
            #define chk(a,b,c) if(a+b>c) c=a+b;
            #define par(p) chk(a[p],b[p],c[i])
            par(j) par(j+1) par(j+2) par(j+3)
            par(j+4) par(j+5) par(j+6) par(j+7)
            }
        }
        for(register int i=0;i<n;i++) pint(c[i]),pc(' ');pc(10);
        cerr<<clock()<<"ms
    ";
    }

    这样理论上cpu可以对中间的代码乱序执行,就是一次执行很多条,从而提高运行速度。

    实测优化效果非常好,只跑了2.9s,比原来快了1倍多。

    此外我还了解到openmp和cache blocking这两种优化方法,但是对程序提速不明显,这里就不提了,有兴趣的自行度娘。

    5. Intrinsic

    这是真正的黑科技了= =orz小火车

    #include "immintrin.h"
    #include "emmintrin.h"
    static __m256i a_m[SZ],b_m[8][SZ];
    static int a[SZ],b[SZ],c[SZ];
    __attribute__((target("avx2")))
    inline int gmax(__m256i qwq)
    {
        int*g=(int*)&qwq,ans=0;
        (g[0]>ans)?(ans=g[0]):0;
        (g[1]>ans)?(ans=g[1]):0;
        (g[2]>ans)?(ans=g[2]):0;
        (g[3]>ans)?(ans=g[3]):0;
        (g[4]>ans)?(ans=g[4]):0;
        (g[5]>ans)?(ans=g[5]):0;
        (g[6]>ans)?(ans=g[6]):0;
        (g[7]>ans)?(ans=g[7]):0;
        return ans;
    }
    __attribute__((target("avx2")))
    int main()
    {
        const int n=131072;
        memset(a,-127/3,sizeof(a));
        memset(b,-127/3,sizeof(b));
        for(register int i=0;i<n;i++) a[i]=gi;
        for(register int i=0;i<n;i++) b[n-i]=gi;
        for(register int i=0;i<=n+5;i+=8)
            a_m[i>>3]=_mm256_set_epi32
            (a[i],a[i+1],a[i+2],a[i+3],
            a[i+4],a[i+5],a[i+6],a[i+7]);
        for(register int r=0;r<8;++r)
        for(register int i=0;i<=n+67;i+=8)
            b_m[r][i>>3]=_mm256_set_epi32
            (b[i+r],b[i+1+r],b[i+2+r],b[i+3+r],
            b[i+4+r],b[i+5+r],b[i+6+r],b[i+7+r]);
        __m256i zero=_mm256_set_epi32(0,0,0,0,0,0,0,0);
        for(register int i=0,lj;i<n;i++)
        {
            __m256i*r=b_m[(n-i)&7]+((n-i)>>3),
            qwq=zero; lj=(i>>3);
            for(register int j=0;j<=lj;j+=8)
            {
    #define par(p) qwq=_mm256_max_epi32(
    qwq,_mm256_add_epi32(a_m[p],r[p]));
                par(j) par(j+1) par(j+2) par(j+3)
                par(j+4) par(j+5) par(j+6) par(j+7)
            }
            c[i]=gmax(qwq);
        }
        for(register int i=0;i<n;i++)
            pint(c[i]),pc(' ');pc(10);
        cerr<<clock()<<"ms
    ";
    }

    这段代码有点长,这里解释一下原理。

    大家都知道stl中有一个很好用的库叫bitset,它的原理是将32/64个bit(取决于字长)压成一个数(long),从而使常数/=32或64。

    在intel部分指令集中,有类似的数据类型,可以将多个int/float/double等等压在一个128/256/512位的数据类型里,从而一起进行计算。

    大致有三种数据类型:

    __m128i __m256i __m512i

    分别对应压在128/256/512位内。

    我们可以对这三种数据类型中压的数进行并行计算!例如一个__m256i里可以包8个int。这些数据类型的方法有点多,intel提供了一个可以查找这些方法的页面:https://software.intel.com/sites/landingpage/IntrinsicsGuide/

    实现起来相当于手写bitset,细节详见代码吧。

    这段代码只跑了1.4s!又比循环展开快了一倍。

    我实测了一下,在uoj上__m256i无法使用,__m128i只能使用部分指令,例如_mm_max_epi32这个指令就不支持......洛谷上可以正常运行。

    upd:一些卡常方面的tips:

    register基本没用。

    全局数组开static(有可能)可以放进L1~L3,明显加快速度。

    多维数组后面几维别开2的次幂,可能导致cache miss。

    inline的速度比手动展开或__attribute__((always_inline))慢。

    一般什么卡常都比不上输入输出优化效果好= =

  • 相关阅读:
    spark调度器FIFO,FAIR
    elasticsearch5.6.8 创建TransportClient工具类
    elasticsearch TransportClient bulk批量提交数据
    java 参数来带回方法运算结果
    idea上传代码到git本地仓库
    2020-03-01 助教一周小结(第三周)
    2020-02-23 助教一周小结(第二周)
    2020-02-16 助教一周小结(第一周)
    寻找两个有序数组的中位数
    无重复字符的最长子串
  • 原文地址:https://www.cnblogs.com/zzqsblog/p/6666755.html
Copyright © 2011-2022 走看看