首页 > 代码库 > 初涉矩阵快速幂

初涉矩阵快速幂

矩阵快速幂一般用来加速（线性）递推。

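简单地，对于 Fibonacci 数列有 $f_0 = 1,\ f_1 = 1,\ f_n = f_{n-1} + f_{n-2}\ (n \ge 2)$。
则对于 $f_n$ 有：

$$
\begin{pmatrix} f_n \\ f_{n-1} \end{pmatrix}
= \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}
  \begin{pmatrix} f_{n-1} \\ f_{n-2} \end{pmatrix}
= \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{n-1}
  \begin{pmatrix} f_1 \\ f_0 \end{pmatrix}
$$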

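一般地，对于 $k$ 阶线性递推 $f_n = a_1 f_{n-1} + a_2 f_{n-2} + \cdots + a_k f_{n-k}$，有：

$$
\begin{pmatrix} f_n \\ f_{n-1} \\ \vdots \\ f_{n-k+1} \end{pmatrix}
= A
\begin{pmatrix} f_{n-1} \\ f_{n-2} \\ \vdots \\ f_{n-k} \end{pmatrix}
= A^{\,n-k+1}
\begin{pmatrix} f_{k-1} \\ f_{k-2} \\ \vdots \\ f_0 \end{pmatrix},
\qquad
A = \begin{pmatrix}
a_1 & a_2 & \cdots & a_{k-1} & a_k \\
1 & 0 & \cdots & 0 & 0 \\
0 & 1 & \cdots & 0 & 0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & 1 & 0
\end{pmatrix}
$$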


又因为矩阵乘法满足结合律，所以可以用快速幂（反复平方）求出 $A^n$：只需 $O(\log n)$ 次 $k \times k$ 矩阵乘法，即 $O(k^3 \log n)$ 的时间，就能代替 $O(n)$ 的逐项递推。


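顺便记一个小技巧：求矩阵幂级数 $S_k = A + A^2 + \cdots + A^k$ 时，可以构造分块矩阵

$$
M = \begin{pmatrix} A & I \\ 0 & I \end{pmatrix},
\qquad
M^{t} = \begin{pmatrix} A^{t} & I + A + \cdots + A^{t-1} \\ 0 & I \end{pmatrix},
$$

于是 $M^{k+1}$ 右上角的块减去单位矩阵 $I$ 即为 $S_k$。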


以 POJ 3233（Matrix Power Series：给定 $n \times n$ 矩阵 $A$ 以及 $k$、$m$，求 $S_k = A + A^2 + \cdots + A^k \bmod m$）为例：

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <queue>
#include <cmath>
#include <stack>
#include <map>

#pragma comment(linker, "/STACK:1024000000");
#define EPS (1e-8)
#define LL long long
#define ULL unsigned long long
#define _LL __int64
#define INF 0x3f3f3f3f

using namespace std;

// Modulus applied by MatrixMult; main() assigns it from the input before
// any matrix multiplication takes place.
int Mod = 0;

// Storage capacity of each Mat dimension. Indices are 1-based, so a Mat can
// hold at most MAXN - 1 = 60 rows/columns -- enough for the 2n x 2n block
// matrix that main() builds as long as n <= 30.
const int MAXN = 61;

// Fixed-capacity dense matrix with 1-based indexing: only mat[1..r][1..c]
// carries data; row 0 and column 0 are never touched by this file.
struct Mat
{
    LL mat[MAXN][MAXN];
    int r,c;

    // Resize to R x C and write val on the main diagonal (row == col) and 0
    // everywhere else, so Init(0, ...) yields a zero matrix and Init(1, ...)
    // an identity. Cells outside [1..R] x [1..C] are left as they were.
    // R and C must stay below MAXN; no bounds check is done here.
    void Init(int val,int R,int C)
    {
        r = R;
        c = C;
        for(int row = 1;row <= R; ++row)
        {
            for(int col = 1;col <= C; ++col)
                mat[row][col] = (row == col) ? val : 0;
        }
    }
};

/*
 * Returns the product a * b with every entry reduced modulo the global Mod.
 *
 * Matrices are 1-indexed. The result is a.r x b.c, and the shared dimension
 * is read from b.r, so callers must pass conformable matrices (a.c == b.r).
 *
 * Both operands are taken by const reference. Each Mat embeds a full
 * MAXN x MAXN array of LL, so by-value parameters used to copy two large
 * structs on every call. Passing the same matrix as both arguments
 * (e.g. squaring) is safe, because the result is built in a separate local.
 */
Mat MatrixMult(const Mat &a,const Mat &b)
{
    Mat p;
    p.Init(0,a.r,b.c);

    // i-k-j order scans rows of b and p contiguously. For every (i, j) the
    // terms are still added in increasing k, reducing after each addition,
    // so the result matches the textbook i-j-k order exactly.
    for(int i = 1;i <= a.r; ++i)
    {
        for(int k = 1;k <= b.r; ++k)
        {
            const LL aik = a.mat[i][k];
            // A zero factor contributes nothing; the accumulators are
            // already reduced, so skipping the extra "% Mod" changes nothing.
            if(aik == 0)
                continue;
            for(int j = 1;j <= b.c; ++j)
            {
                p.mat[i][j] += aik*b.mat[k][j];
                p.mat[i][j] %= Mod;
            }
        }
    }

    return p;
}

/*
 * Raises the square matrix coe to the k-th power modulo the global Mod using
 * binary exponentiation (square-and-multiply): O(log k) calls to MatrixMult.
 *
 * k   : the exponent. Any k <= 0 yields the identity of coe's size, which is
 *       all zeros when Mod == 1 because 1 is congruent to 0 mod 1.
 * coe : the base matrix, taken by value because it is squared in place.
 * returns coe^k. For k >= 1 at least one MatrixMult writes the result, so
 *       every entry has been reduced by "% Mod".
 */
Mat QuickMult(_LL k,Mat coe)
{
    Mat res;
    res.Init(Mod == 1 ? 0 : 1, coe.r, coe.c);

    while(k > 0)
    {
        if(k & 1)
            res = MatrixMult(res,coe);
        k >>= 1;
        // Once the last bit has been consumed the squared base is never read,
        // so skip that final (and most expensive) useless multiplication.
        if(k > 0)
            coe = MatrixMult(coe,coe);
    }

    return res;
}

int main()
{
    _LL n,k,m;

    int i,j;

    Mat A,B;

    scanf("%lld %lld %lld",&n,&k,&m);

    Mod = m;

    for(i = 1;i <= n; ++i)
    {
        for(j = 1;j <= n; ++j)
            scanf("%lld",&A.mat[i][j]);
    }

    for(i = 1;i <= n; ++i)
    {
        for(j = 1;j <= n; ++j)
        {
            if(i == j)
                A.mat[i][j+n] = 1;
            else
                A.mat[i][j+n] = 0;
        }
    }

    for(i = 1;i <= n; ++i)
    {
        for(j = 1;j <= n; ++j)
        {
            A.mat[i+n][j] = 0;
        }
    }

    for(i = 1;i <= n; ++i)
    {
        for(j = 1;j <= n; ++j)
        {
            if(i == j)
                A.mat[i+n][j+n] = 1;
            else
                A.mat[i+n][j+n] = 0;
        }
    }

    A.r = 2*n,A.c = 2*n;

    A = QuickMult(k+1,A);

    for(i = 1;i <= n; ++i)
    {
        if(A.mat[i][i+n])
            A.mat[i][i+n]--;
        else
            A.mat[i][i+n] = m-1;
    }

    for(i = 1;i <= n; ++i)
    {
        for(j = 1;j <= n; ++j)
        {
            printf("%lld",A.mat[i][j+n]);
            if(j == n)
                printf("\n");
            else
                printf(" ");
        }
    }

    return 0;
}