0
点赞
收藏
分享

微信扫一扫

算法导论--------------Strassen矩阵乘法


由于看到 了动态规划来分析解决矩阵链乘的问题,所以回顾了一下矩阵乘法,发现这个知识点忘记的差不多了

,现在再来总结一下。

首先我们知道两个矩阵相乘A*B,那么A的列数必须等于B的行数,否则不能进行相乘.

首先我们来回顾一下解决矩阵相乘问题的一般方法:利用三个for循环来解决,时间复杂度为o(n^3)。

矩阵乘法定义:

例如有两个n乘以n的矩阵A和B,C=A*B;

那么求C的过程为:

#include<iostream>
using namespace std;
int main()
{
int a[3][4] = {
1, 3, 3, 1,
1, 2, 4, 1,
2, 6, 5, 1
};
int b[4][2] = {
2, 1,
2, 1,
2, 1,
2, 1
};

int c[3][2] = { 0 };

for (int i = 0; i < 3; ++i)
{
for (int j = 0; j < 2; ++j)
{
c[i][j] = 0;
for (int k = 0; k < 4; ++k)
c[i][j] = c[i][j] + a[i][k] * b[k][j];
}
}

for (int i = 0; i < 3; ++i)
{
int k = 0;
for (int j = 0; j < 2; ++j)
{
cout << c[i][j] << " ";
k++;
if (k == 2)
{
cout << endl;
}
}
}
return 0;
}

现在我们来看看Strassen算法的原理:

一般情况下矩阵乘法需要三个for循环,时间复杂度为O(n^3),现在我们将矩阵分块如图:( 来自MIT算法导论 )

                                        

算法导论--------------Strassen矩阵乘法_ios

一般算法需要八次乘法

r = a * e + b * g ;

s = a * f  + b * h ;

t = c * e + d  * g; 

u = c * f + d * h;

strassen将其变成7次乘法,因为大家都知道乘法比加减法消耗更多,所有时间复杂更高!

strassen的处理是:

令:

p1 = a * ( f - h )

p2 = ( a + b ) *  h

p3 = ( c +d ) * e

p4 = d *  ( g - e )

p5 = ( a + d ) * ( e + h )

p6 =  ( b - d ) * ( g + h ) 

p7 = ( a - c ) * ( e + f )

那么我们可以知道:

r  = p5 + p4 + p6 - p2

s = p1 + p2

t = p3 + p4

u = p5 + p1 - p3 - p7

我们可以看到上面只有7次乘法和多次加减法,最终达到降低复杂度为O( n^lg7 ) ~= O( n^2.81 );

其代码实现过程为:其中n必须为2的幂

#include<iostream>
using namespace std;

#define N 4

//matrix + matrix
void my_plus(int t[N / 2][N / 2], int r[N / 2][N / 2], int s[N / 2][N / 2])
{
int i, j;
for (i = 0; i < N / 2; i++)
{
for (j = 0; j < N / 2; j++)
{
t[i][j] = r[i][j] + s[i][j];
}
}
}

//matrix - matrix
void my_minus(int t[N / 2][N / 2], int r[N / 2][N / 2], int s[N / 2][N / 2])
{
int i, j;
for (i = 0; i < N / 2; i++)
{
for (j = 0; j < N / 2; j++)
{
t[i][j] = r[i][j] - s[i][j];
}
}
}

//matrix * matrix
void my_mul(int t[N / 2][N / 2], int r[N / 2][N / 2], int s[N / 2][N / 2])
{
int i, j, k;
for (i = 0; i < N / 2; i++)
{
for (j = 0; j < N / 2; j++)
{
t[i][j] = 0;
for (k = 0; k < N / 2; k++)
{
t[i][j] += r[i][k] * s[k][j];
}
}
}
}

int main()
{
int i, j, k;
int mat[N][N];
//int m1[N][N];
//int m2[N][N];
int a[N / 2][N / 2], b[N / 2][N / 2], c[N / 2][N / 2], d[N / 2][N / 2];
int e[N / 2][N / 2], f[N / 2][N / 2], g[N / 2][N / 2], h[N / 2][N / 2];
int p1[N / 2][N / 2], p2[N / 2][N / 2], p3[N / 2][N / 2], p4[N / 2][N / 2];
int p5[N / 2][N / 2], p6[N / 2][N / 2], p7[N / 2][N / 2];
int r[N / 2][N / 2], s[N / 2][N / 2], t[N / 2][N / 2], u[N / 2][N / 2], t1[N / 2][N / 2], t2[N / 2][N / 2];

int m1[4][4] = {
4, 4, 4, 4,
4, 4, 4, 4,
4, 4, 4, 4,
4, 4, 4, 4 };
int m2[4][4] = {
2, 2, 2, 2,
2, 2, 2, 2,
2, 2, 2, 2,
2, 2, 2, 2 };



// a b c d e f g h
for (int i = 0; i < N / 2; i++)
{
for (int j = 0; j < N / 2; j++)
{
a[i][j] = m1[i][j];
b[i][j] = m1[i][j + N / 2];
c[i][j] = m1[i + N / 2][j];
d[i][j] = m1[i + N / 2][j + N / 2];
e[i][j] = m2[i][j];
f[i][j] = m2[i][j + N / 2];
g[i][j] = m2[i + N / 2][j];
h[i][j] = m2[i + N / 2][j + N / 2];
}
}

//p1
my_minus(r, f, h);
my_mul(p1, a, r);

//p2
my_plus(r, a, b);
my_mul(p2, r, h);

//p3
my_plus(r, c, d);
my_mul(p3, r, e);

//p4
my_minus(r, g, e);
my_mul(p4, d, r);

//p5
my_plus(r, a, d);
my_plus(s, e, f);
my_mul(p5, r, s);

//p6
my_minus(r, b, d);
my_plus(s, g, h);
my_mul(p6, r, s);

//p7
my_minus(r, a, c);
my_plus(s, e, f);
my_mul(p7, r, s);

//r = p5 + p4 - p2 + p6
my_plus(t1, p5, p4);
my_minus(t2, t1, p2);
my_plus(r, t2, p6);

//s = p1 + p2
my_plus(s, p1, p2);

//t = p3 + p4
my_plus(t, p3, p4);

//u = p5 + p1 - p3 - p7 = p5 + p1 - ( p3 + p7 )
my_plus(t1, p5, p1);
my_plus(t2, p3, p7);
my_minus(u, t1, t2);

for (int i = 0; i < N / 2; i++)
{
for (int j = 0; j < N / 2; j++)
{
mat[i][j] = r[i][j];
mat[i][j + N / 2] = s[i][j];
mat[i + N / 2][j] = t[i][j];
mat[i + N / 2][j + N / 2] = u[i][j];
}
}

printf("\n下面是strassen算法处理结果:\n");
for (i = 0; i < N; i++)
{
for (j = 0; j < N; j++)
{
printf("%d ", mat[i][j]);
}
printf("\n");
}

//下面是朴素算法处理
printf("\n下面是朴素算法处理结果:\n");
for (i = 0; i < N; i++)
{
for (j = 0; j < N; j++)
{
mat[i][j] = 0;
for (k = 0; k < N; k++)
{
mat[i][j] += m1[i][j] * m2[i][j];
}
}
}

for (i = 0; i < N; i++)
{
for (j = 0; j < N; j++)
{
printf("%d ", mat[i][j]);
}
printf("\n");
}

return 0;
}


举报

相关推荐

0 条评论