首页 > 代码库 > strassen算法——矩阵乘法

strassen算法——矩阵乘法

strassen算法可以看做是分治递归法求解矩阵乘法的改进。

利用分治递归法求解矩阵乘法的过程大致:

矩阵C = A * B(A、B、C都是n x n矩阵)


可以发现(A11 * B11)、(A12 * B21)……等子矩阵的乘法运算需要继续递归。上面有8个乘法,所以需要递归8次。
时间复杂度关系公式 T(n) = 8T(n/2) + O(n^2),这里8T(n/2)8次递归,O(n^2)是求C11,C12,C21,C22所做的加法,因为(A11*B11)、(A12*B21)……都有n^2 / 4个元素。
通过推导公式T(n) = 8T(n/2) + O(n^2),我们会得到T(n) = O(n^3),等等……这不是跟普通的3次循环做矩阵乘法的方法时间复杂度一样吗。

不要着急,算法还有改善的空间。
请看上面红框所选的部分,是不是可以想办法减少一下乘法的数量,8次乘法就代表8次递归。8次乘法中有很多乘法因子都是相同的……
这确实是个不小的挑战,很考验数学功底。笔者写满了一张演草纸也没有找到好方法。看了下strassen算法的实现,也只省去了一个乘法,还剩7个。多少有些失望……

Strassen算法的简化思路是——尽量用加减来代替乘法

推导结果:

C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7

P1 = A11 * S1
P2 = S2 * B22
P3 = S3 * B11
P4 = A22 * S4
P5 = S5 * S6
P6 = S7 * S8
P7 = S9 * S10

S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12

上面的计算过程有7个乘法,分别是计算P1、P2……

你可能会向笔者一样思考,这种方法是怎么想出来的呢?他是经过怎样的反复试验和推敲才找到了这种思路呢?我想,这只有他自己最清楚了。就像数学公式一样,我们只要知道结论就可以了,这并不妨碍我们以后的应用。

经过这一折腾,算法的时间复杂度公式变为T(n) = 7T(n/2) + O(n^2),经过推导得出T(n) = O(n^lg7)lg7约为2.81,所以最终T(n) = O(n^2.81)

strassen算法——矩阵乘法