首页 > 代码库 > HDU4965-Fast Matrix Calculation(矩阵快速幂)

HDU4965-Fast Matrix Calculation(矩阵快速幂)

题目链接


题意:n*k的矩阵A和一个k*n的矩阵B,C = A * B。求M = (C)^(n * n)时,矩阵M中每个元素的和(每个元素都要MOD6)

思路:因为n最大到1000,所以不能直接用矩阵快速幂求AB的n*n次幂,但是可以将公式稍微转换下,M = AB * AB...* AB = A * (BA) *... * (BA) * B,这样BA的n*n -1次幂就能用矩阵快速幂求解,之后再分别乘以A,B即可。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 1005;
const int MOD = 6;

struct mat{
    int s[6][6];
    mat() {
        memset(s, 0, sizeof(s)); 
    }
    mat operator * (const mat& c) {
        mat ans; 
        memset(ans.s, 0, sizeof(ans.s));
        for (int i = 0; i < 6; i++)
            for (int j = 0; j < 6; j++)
                for (int k = 0; k < 6; k++)
                    ans.s[i][j] = (ans.s[i][j] + s[i][k] * c.s[k][j]) % MOD;
        return ans;
    }
};

int a[N][6], b[6][N], temp[N][N], sum[N][N];
int n, m;

void init() {
    memset(a, 0, sizeof(a));
    memset(b, 0, sizeof(b));
    for (int i = 0; i < n; i++)
        for (int j = 0; j < m; j++)
            scanf("%d", &a[i][j]);
    for (int i = 0; i < m; i++)
        for (int j = 0; j < n; j++)
            scanf("%d", &b[i][j]);
}

mat pow_mod(mat c, int k) {
    if (k == 1)
        return c;
    mat a = pow_mod(c, k / 2);
    mat ans = a * a;
    if (k % 2)
        ans = ans * c;
    return ans;
}

int main() {
    while (scanf("%d%d", &n, &m)) {
        if (n == 0 && m == 0) 
            break;
        init();
        mat c;
        for (int i = 0; i < m; i++)
            for (int j = 0; j < m; j++)
                for (int k = 0; k < n; k++)
                    c.s[i][j] = (c.s[i][j] + b[i][k] * a[k][j]) % MOD;
        int cnt = n * n - 1;
        mat ans = pow_mod(c, cnt);

        memset(temp, 0, sizeof(temp));
        for (int i = 0; i < n; i++) 
            for (int j = 0; j < m; j++)
                for (int k = 0; k < m; k++) 
                    temp[i][j] = (temp[i][j] + a[i][k] * ans.s[k][j]) % MOD;
        cnt = 0;
        memset(sum, 0, sizeof(sum));
        for (int i = 0; i < n; i++) 
            for (int j = 0; j < n; j++) {
                for (int k = 0; k < m; k++) 
                    sum[i][j] = (sum[i][j] + temp[i][k] * b[k][j]) % MOD;
                cnt += sum[i][j];
            }
        printf("%d\n", cnt); 
    }
    return 0;
}


HDU4965-Fast Matrix Calculation(矩阵快速幂)