zoukankan      html  css  js  c++  java
  • 关于矩阵快速幂的若干优化

    首先,我们复习一下矩阵乘法。

    我们记3个矩阵A(a行b列),B(b行c列),C(a行c列)。我们要计算A*B,并把答案存到矩阵C中。

    C[i][j]+=A[i][k]*B[k][j](1<=i<=a,1<=j<=c,1<=k<=b),即新矩阵的第i行第j个元素是原1矩阵的第i行*原2矩阵的第j列得来的。

    一般来说,我们的计算方法是for(int i=1;i<=a;i++)for(int j=1;j<=c;j++)for(int k=1;k<=b;k++)C[i][j]+=A[i][k]*B[k][j];

    其次,让我们复习一下快速幂。

    举个例子吧,计算a^101。

    我们知道,于是:

    a^101=a^(1*2^6)*a^(1*2^5)*a^(0*2^4)*a^(1*2^3)*a^(1*2^2)*a^(0*2^1)*a^(1*2^0)。

    我们把101转成2进制:1101101。每个2^x前的系数就是二进制第x位的数。

    a^ (2^x)=a^(2^(x-1))^2。我们可以通过a^(2^(x-1))来求得a^(2^x)。

    这样,对于二进制下的第x位,该位如果为1,就把ans*=a(更新答案,初始化为1)。然后每次a*=a(用a^(2^x)更新出a^(2^(x+1)),准备处理下一位)。

    我们便可以在O(logp)(p为指数)的时间复杂度内出解。

    最后,让我们来复习一下矩阵快速幂。

    我们要求A^B^B^B^B^B^B^B......(A,B为矩阵),即A^(B^p)的值。

    就像ans初值=1一样,记一个单位矩阵(主对角线为1)Ans,结合上面两种做法,我们就可以求出A^(B^p)的值。

    (1)对于稀疏矩阵的优化

    稀疏矩阵,即为矩阵中有很多元素为0。

    优化方法:改变循环顺序。改为for(int i=1;i<=a;i++)for(int k=1;k<=b;k++)for(int j=1;j<=c;j++)C[i][j]+=A[i][k]*B[k][j];

    这样有什么好处呢?

    我们可以发现,只要A[i][k]==0,那么对答案矩阵(C)不会有任何贡献。

    所以我们可以进行优化,在第二个循环到第三个循环直接加一个if,若A[i][k]!=0,才进入第三个循环。

    for(int i=1;i<=a;i++)for(int k=1;k<=b;k++)if(A[i][k])for(int j=1;j<=c;j++)C[i][j]+=A[i][k]*B[k][j];

    题目:POJ 3735 Training little cats。

    (2)预处理优化矩阵快速幂

    主要针对多组数据。求A*B^k。给出A,B,T个询问k

    在通常情况下,A是一个n行1列的矩阵,B是一个n行n列的矩阵。这样,我们的矩阵快速幂(求A^(B^k))的复杂度就是O((n^3logk+Tn^2logk))。

    具体来说,我们先用O(n^3log(maxk))预处理出B^(2^p),再A*B^k=A*B^(2^a1)+A*B^(2^a2)+...算答案。复杂度O(Tn^2logk)

    (3)优化快速幂过程

    主要针对多组数据。

    正常的快速幂的当次复杂度为O(log2(n))。看到那个2了吗,我们的工作就是要把这个2变大。

    考虑一般的快速幂,一般的快速幂是以2进制为基础的,我们考虑用3进制为基础会怎么样。

    对于每一个3进制位,如果该位是0,ans*=x^0,如果该位是1,ans*=x^1,如果该位是2,ans*=x^2

    与2进制快速幂同理,每次x=x^3,p=p/3

    所以复杂度是O(klogk(n)),k为进制

    但是虽然这个2变大了,复杂度却一点也没变小

    但是这并不能阻挡我们优化的决心,如果每次询问的底数都相同,我们是能优化的

    预处理mi[a][b]=(x^(k^a))^b即可,每次ans*=mi[a][b],a是当前做到第几位,b是当前这位的数

    mi[a][1]=mi[a-1][k-1]*mi[a-1][1]

    mi[a][b]=mi[a][b-1]*mi[a][1]

    这样复杂度变为(klogk(n)+logk(n))。

    (4)常数优化

    ikj循环,循环展开 for(int i = 1; i <= n; i++) for(int k = 1; k <= n; k++) for(int j = 1; j <= n; j++) c[i][j] += a[i][k] * b[k][j];

    这样能保证b数组的内存访问是连续的

    拥有上面全部优化的模板题:https://www.luogu.org/problemnew/show/P5107

    #include <cstdio>
    #include <cstring>
    #define mod 998244353
    #define T 256
    #include <algorithm>
    
    struct xxx{
        int a[52][52];
    };
    struct xx{
        int a[52];
    };
    struct QQ{
        int x, id;
    }q[50100];
    int n, d[55];
    xxx mi[4][T + 1];
    xx ans;
    long long Ans[50100];
    
    bool cmp(QQ a, QQ b) {return a.x < b.x;}
     
    xxx operator * (xxx a, xxx b)
    {
        xxx c; memset(c.a, 0, sizeof(c.a));
        for(int i = 1; i <= n; i++)
            for(int k = 1; k <= n; k++)
                if(a.a[i][k])
                for(int j = 1; j <= n; j++)
                    c.a[i][j] = (c.a[i][j] + 1ll * a.a[i][k] * b.a[k][j]) % mod;
        return c;
    }
    
    xx operator * (xx a, xxx b) 
    {
        xx c; memset(c.a, 0, sizeof(c.a));
        for(int j = 1; j <= n; j++)
            for(int i = 1; i <= n; i++)
                c.a[j] = (c.a[j] + 1ll * a.a[i] * b.a[i][j]) % mod;
        return c;
    }
    
    int qpow(int x, int p)
    {
        int ans = 1;
        while(p)
        {
            if(p & 1) ans = 1ll * ans * x % mod;
            x = 1ll * x * x % mod; p >>= 1;
        }
        return ans;
    }
    
    xx operator ^ (xx a, int p)
    {
        int j = 0;
        while(p)
        {
            ans = ans * mi[j][p & 255];
            j++; p >>= 8;
        }
        return ans;
    }
    
    int main()
    {
        int m, Q; scanf("%d%d%d", &n, &m, &Q);
        for(int i = 1; i <= n; i++) scanf("%d", &ans.a[i]), mi[0][1].a[i][i] = 1, d[i] = 1;
        for(int i = 1; i <= m; i++)
        {
            int u, v; scanf("%d%d", &u, &v);
            mi[0][1].a[u][v]++; d[u]++;
        }
        for(int i = 1; i <= n; i++)
            for(int j = 1; j <= n; j++)
                mi[0][1].a[i][j] = 1ll * mi[0][1].a[i][j] * qpow(d[i], mod - 2) % mod;
        for(int i = 0; i <= 3; i++)
        {
            for(int j = 0; j < T; j++)
            {
                if(i == 0 && j == 1) continue;
                if(j == 0) for(int k = 1; k <= n; k++) mi[i][j].a[k][k] = 1;
                else if(j == 1) mi[i][j] = mi[i - 1][T - 1] * mi[i - 1][1];
                else mi[i][j] = mi[i][j - 1] * mi[i][1];
            }
        }
        for(int i = 1; i <= Q; i++)
        {
            scanf("%d", &q[i].x);
            q[i].id = i;
        }
        std::sort(q + 1, q + Q + 1, cmp);
        for(int i = 1; i <= Q; i++)
        {
            ans = ans ^ (q[i].x - q[i - 1].x);
            for(int j = 1; j <= n; j++) Ans[q[i].id] = Ans[q[i].id] ^ ans.a[j];
            Ans[q[i].id] %= mod;
        }
        for(int i = 1; i <= Q; i++) printf("%lld
    ", Ans[i]);
    }
  • 相关阅读:
    『转』QueryPerformanceFrequency()
    『转』C++中虚析构函数的作用
    存储过程的优缺点
    一个工作7年的软件工程师的总结(收藏)
    存储过程分页算法(收藏)
    Ajax原理(收藏)
    七大秘籍成就职场王者(收藏)
    视图的优缺点
    SQL索引全攻略
    .aspx、MasterPage、.ascx加载顺序
  • 原文地址:https://www.cnblogs.com/lher/p/8024949.html
Copyright © 2011-2022 走看看