zoukankan      html  css  js  c++  java
  • 关于矩阵快速幂的若干优化

    首先,我们复习一下矩阵乘法。

    我们记3个矩阵A(a行b列),B(b行c列),C(a行c列)。我们要计算A*B,并把答案存到矩阵C中。

    C[i][j]+=A[i][k]*B[k][j](1<=i<=a,1<=j<=c,1<=k<=b),即新矩阵的第i行第j个元素是原1矩阵的第i行*原2矩阵的第j列得来的。

    一般来说,我们的计算方法是for(int i=1;i<=a;i++)for(int j=1;j<=c;j++)for(int k=1;k<=b;k++)C[i][j]+=A[i][k]*B[k][j];

    其次,让我们复习一下快速幂。

    举个例子吧,计算a^101。

    我们知道,于是:

    a^101=a^(1*2^6)*a^(1*2^5)*a^(0*2^4)*a^(1*2^3)*a^(1*2^2)*a^(0*2^1)*a^(1*2^0)。

    我们把101转成2进制:1101101。每个2^x前的系数就是二进制第x位的数。

    a^ (2^x)=a^(2^(x-1))^2。我们可以通过a^(2^(x-1))来求得a^(2^x)。

    这样,对于二进制下的第x位,该位如果为1,就把ans*=a(更新答案,初始化为1)。然后每次a*=a(用a^(2^x)更新出a^(2^(x+1)),准备处理下一位)。

    我们便可以在O(logp)(p为指数)的时间复杂度内出解。

    最后,让我们来复习一下矩阵快速幂。

    我们要求A^B^B^B^B^B^B^B......(A,B为矩阵),即A^(B^p)的值。

    就像ans初值=1一样,记一个单位矩阵(主对角线为1)Ans,结合上面两种做法,我们就可以求出A^(B^p)的值。

    (1)对于稀疏矩阵的优化

    稀疏矩阵,即为矩阵中有很多元素为0。

    优化方法:改变循环顺序。改为for(int i=1;i<=a;i++)for(int k=1;k<=b;k++)for(int j=1;j<=c;j++)C[i][j]+=A[i][k]*B[k][j];

    这样有什么好处呢?

    我们可以发现,只要A[i][k]==0,那么对答案矩阵(C)不会有任何贡献。

    所以我们可以进行优化,在第二个循环到第三个循环直接加一个if,若A[i][k]!=0,才进入第三个循环。

    for(int i=1;i<=a;i++)for(int k=1;k<=b;k++)if(A[i][k])for(int j=1;j<=c;j++)C[i][j]+=A[i][k]*B[k][j];

    题目:POJ 3735 Training little cats。

    (2)预处理优化矩阵快速幂

    主要针对多组数据。求A*B^k。给出A,B,T个询问k

    在通常情况下,A是一个n行1列的矩阵,B是一个n行n列的矩阵。这样,我们的矩阵快速幂(求A^(B^k))的复杂度就是O((n^3logk+Tn^2logk))。

    具体来说,我们先用O(n^3log(maxk))预处理出B^(2^p),再A*B^k=A*B^(2^a1)+A*B^(2^a2)+...算答案。复杂度O(Tn^2logk)

    (3)优化快速幂过程

    主要针对多组数据。

    正常的快速幂的当次复杂度为O(log2(n))。看到那个2了吗,我们的工作就是要把这个2变大。

    考虑一般的快速幂,一般的快速幂是以2进制为基础的,我们考虑用3进制为基础会怎么样。

    对于每一个3进制位,如果该位是0,ans*=x^0,如果该位是1,ans*=x^1,如果该位是2,ans*=x^2

    与2进制快速幂同理,每次x=x^3,p=p/3

    所以复杂度是O(klogk(n)),k为进制

    但是虽然这个2变大了,复杂度却一点也没变小

    但是这并不能阻挡我们优化的决心,如果每次询问的底数都相同,我们是能优化的

    预处理mi[a][b]=(x^(k^a))^b即可,每次ans*=mi[a][b],a是当前做到第几位,b是当前这位的数

    mi[a][1]=mi[a-1][k-1]*mi[a-1][1]

    mi[a][b]=mi[a][b-1]*mi[a][1]

    这样复杂度变为(klogk(n)+logk(n))。

    (4)常数优化

    ikj循环,循环展开 for(int i = 1; i <= n; i++) for(int k = 1; k <= n; k++) for(int j = 1; j <= n; j++) c[i][j] += a[i][k] * b[k][j];

    这样能保证b数组的内存访问是连续的

    拥有上面全部优化的模板题:https://www.luogu.org/problemnew/show/P5107

    #include <cstdio>
    #include <cstring>
    #define mod 998244353
    #define T 256
    #include <algorithm>
    
    struct xxx{
        int a[52][52];
    };
    struct xx{
        int a[52];
    };
    struct QQ{
        int x, id;
    }q[50100];
    int n, d[55];
    xxx mi[4][T + 1];
    xx ans;
    long long Ans[50100];
    
    bool cmp(QQ a, QQ b) {return a.x < b.x;}
     
    xxx operator * (xxx a, xxx b)
    {
        xxx c; memset(c.a, 0, sizeof(c.a));
        for(int i = 1; i <= n; i++)
            for(int k = 1; k <= n; k++)
                if(a.a[i][k])
                for(int j = 1; j <= n; j++)
                    c.a[i][j] = (c.a[i][j] + 1ll * a.a[i][k] * b.a[k][j]) % mod;
        return c;
    }
    
    xx operator * (xx a, xxx b) 
    {
        xx c; memset(c.a, 0, sizeof(c.a));
        for(int j = 1; j <= n; j++)
            for(int i = 1; i <= n; i++)
                c.a[j] = (c.a[j] + 1ll * a.a[i] * b.a[i][j]) % mod;
        return c;
    }
    
    int qpow(int x, int p)
    {
        int ans = 1;
        while(p)
        {
            if(p & 1) ans = 1ll * ans * x % mod;
            x = 1ll * x * x % mod; p >>= 1;
        }
        return ans;
    }
    
    xx operator ^ (xx a, int p)
    {
        int j = 0;
        while(p)
        {
            ans = ans * mi[j][p & 255];
            j++; p >>= 8;
        }
        return ans;
    }
    
    int main()
    {
        int m, Q; scanf("%d%d%d", &n, &m, &Q);
        for(int i = 1; i <= n; i++) scanf("%d", &ans.a[i]), mi[0][1].a[i][i] = 1, d[i] = 1;
        for(int i = 1; i <= m; i++)
        {
            int u, v; scanf("%d%d", &u, &v);
            mi[0][1].a[u][v]++; d[u]++;
        }
        for(int i = 1; i <= n; i++)
            for(int j = 1; j <= n; j++)
                mi[0][1].a[i][j] = 1ll * mi[0][1].a[i][j] * qpow(d[i], mod - 2) % mod;
        for(int i = 0; i <= 3; i++)
        {
            for(int j = 0; j < T; j++)
            {
                if(i == 0 && j == 1) continue;
                if(j == 0) for(int k = 1; k <= n; k++) mi[i][j].a[k][k] = 1;
                else if(j == 1) mi[i][j] = mi[i - 1][T - 1] * mi[i - 1][1];
                else mi[i][j] = mi[i][j - 1] * mi[i][1];
            }
        }
        for(int i = 1; i <= Q; i++)
        {
            scanf("%d", &q[i].x);
            q[i].id = i;
        }
        std::sort(q + 1, q + Q + 1, cmp);
        for(int i = 1; i <= Q; i++)
        {
            ans = ans ^ (q[i].x - q[i - 1].x);
            for(int j = 1; j <= n; j++) Ans[q[i].id] = Ans[q[i].id] ^ ans.a[j];
            Ans[q[i].id] %= mod;
        }
        for(int i = 1; i <= Q; i++) printf("%lld
    ", Ans[i]);
    }
  • 相关阅读:
    高级特性(4)- 数据库编程
    UVA Jin Ge Jin Qu hao 12563
    UVA 116 Unidirectional TSP
    HDU 2224 The shortest path
    poj 2677 Tour
    【算法学习】双调欧几里得旅行商问题(动态规划)
    南洋理工大学 ACM 在线评测系统 矩形嵌套
    UVA The Tower of Babylon
    uva A Spy in the Metro(洛谷 P2583 地铁间谍)
    洛谷 P1095 守望者的逃离
  • 原文地址:https://www.cnblogs.com/lher/p/8024949.html
Copyright © 2011-2022 走看看