首页 > 代码库 > Poj 3233 Matrix Power Series(矩阵二分快速幂)
Poj 3233 Matrix Power Series(矩阵二分快速幂)
题目链接:http://poj.org/problem?id=3233
解题报告:输入一个边长为n的矩阵A,然后输入一个k,要你求A + A^2 + A^3 + A^4 + A^5.......A^k,然后结果的每个元素A[i][j] % m。(n <= 30,k < 10^9,m < 10^4)
要用到矩阵快速幂,但我认为最重要的其实还是相加的那个过程,因为k的范围是10^9,一个一个加肯定是不行的,我想了一个办法就是我以k = 8为例说明:
ans = A + A^2 + A^3 + A^4 + A^5 + A^6 + A^7 + A^8
= A + A^2 + A^3 + A^4 + A^4 * (A + A^2 + A^3 + A^4) // 这样分成两块之后就只要算一次A + A^2 + A^3 + A^4,就可以得出最后结果,
而A + A^2 + A^3 + A^4这个又可以通过相同的方法划分成如下:
= A + A^2 + A^2 * (A + A^2)同理。。。。就可以在logn时间求出他们的和了,然后快速的求A^k次方是用二分矩阵快速幂这里就不说了。
1 #include<cstdio> 2 #include<cstring> 3 #include<iostream> 4 #include<algorithm> 5 #include<deque> 6 #include<cmath> 7 using namespace std; 8 typedef __int64 INT; 9 const int N = 31; 10 int m; 11 12 struct node 13 { 14 INT t[N][N]; 15 int n; 16 friend node operator * (node a,node b) 17 { 18 node B = a; 19 for(int i = 0;i < a.n;++i) 20 for(int j = 0;j < a.n;++j) 21 { 22 int tot = 0; 23 for(int k = 0;k < a.n;++k) 24 tot += ((a.t[i][k] * b.t[k][j]) % m); 25 B.t[i][j] = tot % m; 26 } 27 return B; 28 } 29 friend node operator + (node a,node b) 30 { 31 node B; 32 B.n = a.n; 33 for(int i = 0;i < a.n;++i) 34 for(int j = 0;j < a.n;++j) 35 B.t[i][j] = (a.t[i][j] + b.t[i][j]) % m; 36 return B; 37 } 38 }; 39 node A; 40 void print(node t) 41 { 42 for(int i = 0;i < t.n;++i) 43 for(int j = 0;j < t.n;++j) 44 printf(j == t.n-1? "%d\n":"%d ",t.t[i][j]); 45 } 46 node Pw(node tt,int n) //二分矩阵快速幂求A ^ n 47 { 48 if(n == 1) return A; 49 node res; 50 res.n = tt.n; 51 memset(res.t,0,sizeof(res.t)); 52 for(int i = 0;i < res.n;++i) 53 res.t[i][i] = 1; 54 while(n) 55 { 56 if(n & 1) 57 res = res * tt; 58 n >>= 1; 59 tt = tt * tt; 60 } 61 return res; 62 } 63 64 node get_ans(int n) //求A^1 到 A ^ n的和 65 { 66 if(n == 1) 67 return A; 68 if(n & 1) 69 { 70 node temp1; 71 temp1.n = A.n; 72 temp1 = Pw(A,n); 73 node temp2; 74 temp2.n = 2; 75 temp2 = get_ans(n-1); 76 temp1 = temp1 + temp2; 77 return temp1; 78 } 79 else 80 { 81 node temp1; 82 temp1.n = A.n; 83 temp1 = Pw(A,n / 2); 84 node temp2 = get_ans(n / 2); 85 temp2.n = A.n; 86 temp1 = temp1 * temp2; 87 temp1 = temp1 + temp2; 88 return temp1; 89 } 90 } 91 92 int main() 93 { 94 95 int n,k; 96 while(scanf("%d%d%d",&n,&k,&m)!=EOF) 97 { 98 A.n = n; 99 for(int i = 0;i < n;++i)100 for(int j = 0;j < n;++j)101 scanf("%d",&A.t[i][j]);102 node ans = get_ans(k);103 print(ans);104 }105 return 0;106 }
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。