首页 > 代码库 > Poj 3233 Matrix Power Series(矩阵二分快速幂)

Poj 3233 Matrix Power Series(矩阵二分快速幂)

题目链接:http://poj.org/problem?id=3233

解题报告:输入一个边长为n的矩阵A,然后输入一个k,要你求A + A^2 + A^3 + A^4 + A^5.......A^k,然后结果的每个元素A[i][j] % m。(n <= 30,k < 10^9,m < 10^4)

要用到矩阵快速幂,但我认为最重要的其实还是相加的那个过程,因为k的范围是10^9,一个一个加肯定是不行的,我想了一个办法就是我以k = 8为例说明:

ans = A + A^2 + A^3 + A^4 + A^5 + A^6 + A^7 + A^8

= A + A^2 + A^3 + A^4 + A^4 * (A + A^2 + A^3 + A^4)   // 这样分成两块之后就只要算一次A + A^2 + A^3 + A^4,就可以得出最后结果,

而A + A^2 + A^3 + A^4这个又可以通过相同的方法划分成如下:

= A + A^2 + A^2 * (A + A^2)同理。。。。就可以在logn时间求出他们的和了,然后快速的求A^k次方是用二分矩阵快速幂这里就不说了。

  1 #include<cstdio>  2 #include<cstring>  3 #include<iostream>  4 #include<algorithm>  5 #include<deque>  6 #include<cmath>  7 using namespace std;  8 typedef __int64 INT;  9 const int N = 31; 10 int m; 11  12 struct node 13 { 14     INT t[N][N]; 15     int n; 16     friend node operator * (node a,node b) 17     { 18         node B = a; 19         for(int i = 0;i < a.n;++i) 20         for(int j = 0;j < a.n;++j) 21         { 22             int tot = 0; 23             for(int k = 0;k < a.n;++k) 24             tot += ((a.t[i][k] * b.t[k][j]) % m); 25             B.t[i][j] = tot % m; 26         } 27         return B; 28     } 29     friend node operator + (node a,node b) 30     { 31         node B; 32         B.n = a.n; 33         for(int i = 0;i < a.n;++i) 34         for(int j = 0;j < a.n;++j) 35         B.t[i][j] = (a.t[i][j] + b.t[i][j]) % m; 36         return B; 37     } 38 }; 39 node A; 40 void print(node t) 41 { 42     for(int i = 0;i < t.n;++i) 43     for(int j = 0;j < t.n;++j) 44     printf(j == t.n-1? "%d\n":"%d ",t.t[i][j]); 45 } 46 node Pw(node tt,int n)   //二分矩阵快速幂求A ^ n  47 { 48     if(n == 1) return A; 49     node res; 50     res.n = tt.n; 51     memset(res.t,0,sizeof(res.t)); 52     for(int i = 0;i < res.n;++i) 53     res.t[i][i] = 1; 54     while(n) 55     { 56         if(n & 1) 57         res = res * tt; 58         n >>= 1; 59         tt = tt * tt; 60     } 61     return res; 62 }     63  64 node get_ans(int n)    //求A^1 到 A ^ n的和  65 { 66     if(n == 1) 67     return A; 68     if(n & 1) 69     { 70         node temp1; 71         temp1.n = A.n; 72         temp1 = Pw(A,n); 73         node temp2; 74         temp2.n = 2; 75         temp2 = get_ans(n-1); 76         temp1 = temp1 + temp2; 77         return temp1; 78     } 79     else 80     { 81         node temp1; 82         temp1.n = A.n; 83         temp1  = Pw(A,n / 2); 84         node temp2 = get_ans(n / 2); 85         temp2.n = A.n; 86         temp1 = temp1 * temp2; 87         temp1 = temp1 + temp2; 88         return temp1; 89     } 90 } 91  92 int main() 93 { 94      95     int n,k; 96     while(scanf("%d%d%d",&n,&k,&m)!=EOF) 97     { 98         A.n = n; 99         for(int i = 0;i < n;++i)100         for(int j = 0;j < n;++j)101         scanf("%d",&A.t[i][j]);102         node ans = get_ans(k);103         print(ans);104     }105     return 0;106 } 
View Code