首页 > 代码库 > Conquer-Divide的经典例子之Strassen算法解决大型矩阵的相乘

Conquer-Divide的经典例子之Strassen算法解决大型矩阵的相乘

通过汉诺塔问题理解递归的精髓中我讲解了怎么把一个复杂的问题一步步recursively划分了成简单显而易见的小问题。其实这个解决问题的思路就是算法中常用的divide and conquer, 这篇日志通过解决矩阵的乘法,来了解另外一个基本divide and conque思想的strassen算法。

矩阵A乘以B等于X, 则Xij = 
注意左乘右乘的区别,AB 与BA是不同的。
如果r = 1, 直接就是两个数的相乘。
如果r = 2, 例如
X = 
[ 1, 2; 
  3, 4];
Y = 
[ 2, 3;
 4, 5];
R = XY的计算十分简单,但是如果r很大,耗时是O(r^3)。为了简化,可以把X, Y各自划分成2X2的矩阵,每一个元素其实是有n/2行的矩阵
(注:这里仅讲解行数等于列数的情况。)

X = 
[A, B;
C, D];

Y = 
[E, F;
G, H]

所以XY =[
AE+BG, AF+BH;
CE+DG, CF+DH]

Strassen引入seven magic product 分别是P1, P2, P3 ,P4, P5, P6, P7
P1 = A(F-H)
P2 = (A+B)H
P3 = (C+D)E
P4 = D(G-E)
P5 = (A+D)(E+H)
P6 = (B-D)(G+H)
P7 = (A-C)(E+F)

这样XY = 
[P5+P4-P2+P6, P1+P2;
P3+P4, P1+P5-P3-P7]

然后通过递归的策略计算矩阵的相乘,递归的出口是n = 1.

关键点就是这些,附上代码吧。

 

[java] view plaincopy在CODE上查看代码片派生到我的代码片
 
    1. //multiply matrix multiplication  
    2. import java.util.Scanner;  
    3. public class Strassen{  
    4.     public Strassen(){}  
    5.   
    6.   
    7.     /** split a parent matrix into child matrics8*/  
    8.     public static void split(int[][] P, int[][] C, int iB, int jB){  
    9.         for(int i1=0, i2 = iB; i1<C.length; i1++, i2++)  
    10.             for(int j1=0, j2=jB; j1<C.length; j1++, j2++)  
    11.                 C[i1][j1] = P[i2][j2];  
    12.     }  
    13.   
    14.   
    15.     /**join child matric into parent matrix*/  
    16.     public static void join(int[][] C, int[][] P, int iB, int jB){  
    17.         for(int i1=0, i2 = iB; i1<C.length; i1++, i2++)  
    18.             for(int j1=0, j2=jB; j1<C.length; j1++, j2++)  
    19.                 P[i2][j2]=C[i1][j1];   
    20.     }  
    21.   
    22.   
    23.     /**add two matrics into one*/  
    24.     public static int[][] add(int[][] A, int[][] B){  
    25.         //A and B has the same dimension  
    26.         int n = A.length;  
    27.         int[][] C = new int[n][n];  
    28.         for (int i=0; i<n; i++)  
    29.             for(int j=0; j<n; j++)  
    30.                 C[i][j] = A[i][j] + B[i][j];  
    31.                   
    32.         return C;          
    33.     }  
    34.   
    35.   
    36.   
    37.   
    38.     //subtract one matric by another  
    39.     public static int[][] sub(int[][] A, int[][] B){  
    40.         //A and B has the same dimension  
    41.         int n = A.length;  
    42.         int[][] C = new int[n][n];  
    43.         for (int i=0; i<n; i++)  
    44.             for(int j=0; j<n; j++)  
    45.                 C[i][j] = A[i][j] - B[i][j];  
    46.         return C;     
    47.     }  
    48.   
    49.   
    50.     //Multiply matrix  
    51.     public static int[][] multiply(int[][] A, int[][] B){  
    52.         int n = A.length;  
    53.         int[][] R = new int[n][n];  
    54.   
    55.   
    56.         /**exit*/  
    57.         if(n==1)  
    58.             R[0][0] = A[0][0]+B[0][0];  
    59.   
    60.   
    61.         else{  
    62.             //divide A into 4 submatrix  
    63.             int[][] A11 = new int[n/2][n/2];  
    64.             int[][] A12 = new int[n/2][n/2];  
    65.             int[][] A21 = new int[n/2][n/2];  
    66.             int[][] A22 = new int[n/2][n/2];  
    67.   
    68.   
    69.             split(A, A11, 00);  
    70.             split(A, A12, 0, n/2);  
    71.             split(A, A21, n/20);  
    72.             split(A, A22, n/2, n/2);  
    73.   
    74.   
    75.             //divide B into 4 submatric  
    76.             int[][] B11 = new int[n/2][n/2];  
    77.             int[][] B12 = new int[n/2][n/2];  
    78.             int[][] B21 = new int[n/2][n/2];  
    79.             int[][] B22 = new int[n/2][n/2];  
    80.   
    81.   
    82.             split(B, B11, 00);  
    83.             split(B, B12, 0, n/2);  
    84.             split(B, B21, n/20);  
    85.             split(B, B22, n/2, n/2);  
    86.   
    87.   
    88.             //seven magic products  
    89.             int[][] P1 = multiply(A11, sub(B12, B22));  
    90.             int[][] P2 = multiply(add(A11,A12), B22);  
    91.             int[][] P3 = multiply(add(A21, A22), B11);  
    92.             int[][] P4 = multiply(A22, sub(B21, B11));  
    93.             int[][] P5 = multiply(add(A11, A22), add(B11, B22));  
    94.             int[][] P6 = multiply(sub(A12, A22), add(B21, B22));  
    95.             int[][] P7 = multiply(sub(A11, A21), add(B11, B12));  
    96.   
    97.   
    98.   
    99.   
    100.             //new 4 submatrix  
    101.             int[][] R11 = add(add(P5, sub(P4, P2)), P6);  
    102.             int[][] R12 = add(P1, P2);  
    103.             int[][] R21 = add(P3, P4);  
    104.             int[][] R22 = sub(sub(add(P1, P5), P3), P7);  
    105.   
    106.   
    107.             //joint together  
    108.             join(R11, R, 00);  
    109.             join(R12, R, 0, n/2);  
    110.             join(R21, R, n/20);  
    111.             join(R22, R, n/2, n/2);  
    112.   
    113.         }  
    114.         return R;  
    115.     }  
    116.   
    117.   
    118.     //main   
    119.     public static void main(String[] args){  
    120.           
    121.         Scanner scan = new Scanner(System.in);  
    122.         System.out.println("Strassen Multiplication Algorithm Test\n");  
    123.         Strassen s = new Strassen();  
    124.    
    125.   
    126.   
    127.         System.out.println("Fetch the matric A and B...");  
    128.         int N = scan.nextInt();  
    129.         int[][] A = new int[N][N];  
    130.         int[][] B = new int[N][N];  
    131.   
    132.   
    133.         for (int i = 0; i < N; i++)  
    134.             for (int j = 0; j < N; j++)  
    135.                 A[i][j] = scan.nextInt();  
    136.   
    137.   
    138.         for (int i = 0; i < N; i++)  
    139.             for (int j = 0; j < N; j++)  
    140.                 B[i][j] = scan.nextInt();  
    141.   
    142.   
    143.         System.out.println("Fetch Completed!");  
    144.    
    145.         int[][] C = s.multiply(A, B);  
    146.           
    147.         System.out.println("\nmatrices A = ");  
    148.         for (int i = 0; i < N; i++){  
    149.             for (int j = 0; j < N; j++)  
    150.                 System.out.print(A[i][j] +" ");  
    151.             System.out.println();  
    152.         }  
    153.   
    154.   
    155.         System.out.println("\nmatrices B =");  
    156.         for (int i = 0; i < N; i++) {  
    157.             for (int j = 0; j < N; j++)  
    158.                 System.out.print(B[i][j] +" ");  
    159.             System.out.println();  
    160.         }  
    161.    
    162.         System.out.println("\nProduct of matrices A and  B  = ");  
    163.         for (int i = 0; i < N; i++)  
    164.         {  
    165.             for (int j = 0; j < N; j++)  
    166.                 System.out.print(C[i][j] +" ");  
    167.             System.out.println();  
    168.         }  
    169.     }  
    170. }