【算法设计与分析】【C++】斯特拉森Strassen矩阵乘法加速
矩阵乘法加速原理
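- 简单来说,就是把 2×2 分块矩阵乘法所需的八次子矩阵乘法减少为七次,我们可以直接按照下面的式子来写:
  $$M_1 = A_{11}(B_{12}-B_{22}),\quad M_2 = (A_{11}+A_{12})B_{22},\quad M_3 = (A_{21}+A_{22})B_{11},\quad M_4 = A_{22}(B_{21}-B_{11})$$
  $$M_5 = (A_{11}+A_{22})(B_{11}+B_{22}),\quad M_6 = (A_{12}-A_{22})(B_{21}+B_{22}),\quad M_7 = (A_{11}-A_{21})(B_{11}+B_{12})$$
  $$C_{11} = M_5+M_4-M_2+M_6,\quad C_{12} = M_1+M_2,\quad C_{21} = M_3+M_4,\quad C_{22} = M_5+M_1-M_3-M_7$$
- 但是需要注意的是,矩阵分块和矩阵合并本身也要消耗计算资源。显然在矩阵非常小时,直接相乘所消耗的时间远小于分块后再计算的时间,因此程序中 BORDER_SIZE 的选择对程序的运行时间影响极大。经过本人测试,128 是比较好的取值。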
- 接下来上代码(本代码仅限于n*n的矩阵并且n为2的幂次方)
/*
author : TheSun
Time : 2022-03-04 00:44:17
*/
#include<bits/stdc++.h>
using namespace std;
const int BORDER_SIZE = 128; // At or below this size ordinary multiplication is used, because the time spent on the block copies/assignments outweighs the time of the multiplications
// Dense integer matrix with manually managed storage.
// Elements are addressed 1-based: every loop in the member functions runs
// over rows 1..r and columns 1..c, while the storage has r + 1 rows of
// c + 1 ints.
struct Matrix
{
// Number of rows (r) and columns (c).
int r, c;
// Array of r + 1 row pointers; each row holds c + 1 ints.
int **v;
// Allocates an x-by-y matrix and zeroes elements [1..x][1..y].
Matrix(int x=0 , int y=0);
// Copy constructor.
Matrix(const Matrix& others);
// Releases every row and the row-pointer array.
~Matrix();
// Returns the pointer to row n, so m[i][j] addresses one element.
int* operator[](int n);
Matrix operator *(Matrix others);// overloaded as the accelerated (Strassen) matrix multiplication
// Frees the current storage and allocates x rows by y columns; the old
// contents are not preserved.
inline void setSize(int x , int y);
// Assignment; takes its argument by value and returns nothing.
inline void operator=(Matrix others);
// Writes the matrix to cout, one row per line, each value in a width-6 field.
void Print();
// Element-wise sum and difference.
inline Matrix operator+(Matrix& others);
inline Matrix operator-(Matrix& others);
inline Matrix Muti(Matrix& others); // ordinary (triple-loop) matrix multiplication
inline void MatrixDiv(Matrix& A_11 , Matrix& A_12 , Matrix& A_21 , Matrix& A_22); // split this matrix into four quadrant blocks
inline Matrix MergeMatrix(Matrix& A_11 , Matrix& A_12 , Matrix& A_21 , Matrix& A_22); // merge four quadrant blocks into one matrix
// Unfinished placeholder left by the author; no equality operator is declared:
//bool oprator==()
};
// Copy constructor: builds an independent deep copy of `others`.
//
// Allocates r + 1 rows of c + 1 ints and copies every element in the 1-based
// range [1..r] x [1..c]. Row 0 and column 0 are zero-filled padding.
//
// operator* and operator= take their argument by value, so every lvalue
// passed to them goes through this constructor. It therefore has to carry the
// element values over; producing a zero matrix here silently turns those
// operands into zero.
Matrix::Matrix(const Matrix& others)
{
    r = others.r , c = others.c;
    v = new int *[r + 1];
    for(int i=0 ; i<=r ; ++i)
        v[i] = new int [c + 1]();   // value-initialised, so the padding cells are 0
    for(int i=1 ; i<=r ; ++i)
        for(int j=1 ; j<=c ; ++j)
            v[i][j] = others.v[i][j];   // others is const: read its storage directly
}
// Strassen-accelerated matrix product: returns (*this) * others.
// The caller must ensure c == others.r; this is not checked.
//
// Operands whose dimensions are all at most BORDER_SIZE are multiplied with
// the ordinary Muti(). So are operands with an odd dimension, because they
// cannot be split into four equal quadrants. Everything else is split into
// quadrants and assembled from seven recursive products:
//   M_1 = A_11 (B_12 - B_22)            M_5 = (A_11 + A_22)(B_11 + B_22)
//   M_2 = (A_11 + A_12) B_22            M_6 = (A_12 - A_22)(B_21 + B_22)
//   M_3 = (A_21 + A_22) B_11            M_7 = (A_11 - A_21)(B_11 + B_12)
//   M_4 = A_22 (B_21 - B_11)
//   C_11 = M_5 + M_4 - M_2 + M_6        C_12 = M_1 + M_2
//   C_21 = M_3 + M_4                    C_22 = M_5 + M_1 - M_3 - M_7
Matrix Matrix::operator*(Matrix others)
{
    const bool small = r <= BORDER_SIZE && c <= BORDER_SIZE && others.c <= BORDER_SIZE;
    const bool odd = (r % 2 != 0) || (c % 2 != 0) || (others.r % 2 != 0) || (others.c % 2 != 0);
    if (small || odd)
    {
        return this->Muti(others);
    }

    Matrix A_11(r/2 , c/2) , A_12(r/2 , c/2) , A_21(r/2 , c/2) , A_22(r/2 , c/2),
           B_11(others.r/2 , others.c/2) , B_12(others.r/2 , others.c/2) ,
           B_21(others.r/2 , others.c/2) , B_22(others.r/2 , others.c/2);
    this->MatrixDiv(A_11 , A_12 , A_21 , A_22);
    others.MatrixDiv(B_11 , B_12 , B_21 , B_22);

    Matrix M_1 = A_11 * (B_12 - B_22) ,
           M_2 = (A_11 + A_12) * B_22 ,
           M_3 = (A_21 + A_22) * B_11 ,
           M_4 = A_22 * (B_21 - B_11) ,
           M_5 = (A_11 + A_22) * (B_11 + B_22),
           M_6 = (A_12 - A_22) * (B_21 + B_22),
           M_7 = (A_11 - A_21) * (B_11 + B_12);   // A_21 here, not A_12

    // Built directly from the sums; no zero-filled placeholder is allocated first.
    Matrix C_11 = M_5 + M_4 - M_2 + M_6;
    Matrix C_12 = M_1 + M_2;
    Matrix C_21 = M_3 + M_4;
    Matrix C_22 = M_5 + M_1 - M_3 - M_7;
    return this->MergeMatrix(C_11 , C_12 , C_21 , C_22);
}
// Element-wise sum of *this and others.
// The result is sized r x others.c; elements [1..r][1..c] are filled with
// v[i][j] + others[i][j].
inline Matrix Matrix::operator+(Matrix& others)
{
    Matrix sum(r , others.c);
    for (int row = 1; row <= r; ++row)
    {
        int* out = sum[row];
        const int* lhs = v[row];
        const int* rhs = others[row];
        for (int col = 1; col <= c; ++col)
            out[col] = lhs[col] + rhs[col];
    }
    return sum;
}
// Element-wise difference of *this and others.
// The result is sized r x others.c; elements [1..r][1..c] are filled with
// v[i][j] - others[i][j].
Matrix Matrix::operator-(Matrix& others)
{
    Matrix diff(r , others.c);
    for (int row = 1; row <= r; ++row)
    {
        int* out = diff[row];
        const int* lhs = v[row];
        const int* rhs = others[row];
        for (int col = 1; col <= c; ++col)
            out[col] = lhs[col] - rhs[col];
    }
    return diff;
}
// Sizing constructor: x rows by y columns.
// Allocates x + 1 row pointers and y + 1 ints per row, then zeroes the
// 1-based element range [1..x][1..y]. Row 0 and column 0 are allocated but
// not initialised.
Matrix::Matrix(int x , int y)
{
    r = x;
    c = y;
    v = new int *[r + 1];

    int row = 0;
    while (row <= r)
    {
        v[row] = new int [c + 1];
        ++row;
    }

    for (row = 1; row <= r; ++row)
    {
        int* cells = v[row];
        for (int col = 1; col <= c; ++col)
            cells[col] = 0;
    }
}
// Destructor: frees rows 0..r and then the row-pointer array.
// delete[] on a null pointer is a no-op, so no explicit null tests are needed.
Matrix::~Matrix()
{
    int row = 0;
    while (row <= r)
    {
        delete[] v[row];
        ++row;
    }
    delete[] v;
}
// Re-dimensions the matrix to x rows by y columns.
// The existing storage (rows 0..r, using the current r) is released first,
// then x + 1 rows of y + 1 ints are allocated. The previous contents are
// lost and the new cells are left uninitialised.
void Matrix::setSize(int x , int y)
{
    // Release the old storage; delete[] accepts null pointers.
    for (int row = 0; row <= r; ++row)
        delete[] v[row];
    delete[] v;

    // Adopt the new shape and allocate fresh rows.
    r = x;
    c = y;
    v = new int *[r + 1];
    int row = 0;
    while (row <= r)
    {
        v[row] = new int [c + 1];
        ++row;
    }
}
// Ordinary matrix multiplication: returns (*this) * others as an
// r x others.c matrix.
// Uses the i-j-k loop order and skips the innermost loop whenever the
// left-hand element v[i][j] is zero.
Matrix Matrix::Muti(Matrix& others)
{
    Matrix product(r, others.c);
    const int outCols = others.c;
    for (int i = 1; i <= r; ++i)
    {
        int* outRow = product[i];
        for (int j = 1; j <= c; ++j)
        {
            const int lhs = v[i][j];
            if (lhs == 0)
                continue;   // a zero factor contributes nothing to row i
            const int* rhsRow = others[j];
            for (int k = 1; k <= outCols; ++k)
                outRow[k] = outRow[k] + lhs * rhsRow[k];
        }
    }
    return product;
}
// Row accessor: returns the pointer to row n, so m[i][j] reaches one element.
// No bounds checking is performed.
int* Matrix::operator[](int n)
{
    int* row = v[n];
    return row;
}
// Assignment: resizes this matrix to the shape of `others` and copies its
// elements in the 1-based range [1..r] x [1..c].
// `others` is received by value, so it is a separate object from *this.
void Matrix::operator=(Matrix others)
{
    // setSize() walks rows 0..r to free the previous storage, so it must run
    // while r still holds the OLD row count. Writing the new dimensions first
    // would free the wrong number of rows (out-of-bounds delete[] when
    // growing, leaked rows when shrinking). setSize() stores the new r and c.
    this->setSize(others.r , others.c);
    for(int i=1 ; i<=r ; ++i)
        for(int j=1 ; j<=c ; ++j)
            v[i][j] = others[i][j];
}
// Writes the matrix to standard output: one line per row, every element
// right-aligned in a 6-character field, followed by one blank line.
void Matrix::Print()
{
    int row = 1;
    while (row <= r)
    {
        const int* cells = v[row];
        for (int col = 1; col <= c; ++col)
            std::cout << std::setw(6) << cells[col] ;
        std::cout << "\n";
        ++row;
    }
    std::cout << '\n';
}
// Splits this matrix into its four quadrants, writing them into the
// caller-supplied blocks (top-left A_11, top-right A_12, bottom-left A_21,
// bottom-right A_22). Each block is filled starting from index [1][1]; the
// blocks must already be large enough, nothing is resized here.
void Matrix::MatrixDiv(Matrix& A_11 , Matrix& A_12 , Matrix& A_21 , Matrix& A_22)
{
    const int midRow = r / 2;
    const int midCol = c / 2;

    // Copies the window [rowFrom..rowTo] x [colFrom..colTo] of this matrix
    // into dst, shifted so that the window's first cell lands on dst[1][1].
    auto copyWindow = [this](Matrix& dst, int rowFrom, int rowTo, int colFrom, int colTo)
    {
        const int rowShift = rowFrom - 1;
        const int colShift = colFrom - 1;
        for (int i = rowFrom; i <= rowTo; ++i)
            for (int j = colFrom; j <= colTo; ++j)
                dst[i - rowShift][j - colShift] = v[i][j];
    };

    copyWindow(A_11, 1,          midRow, 1,          midCol);
    copyWindow(A_12, 1,          midRow, midCol + 1, c);
    copyWindow(A_21, midRow + 1, r,      1,          midCol);
    copyWindow(A_22, midRow + 1, r,      midCol + 1, c);
}
// Assembles four quadrant blocks into one matrix and returns it.
// The result has twice the rows and twice the columns of A_11; the other
// three blocks are read over the same index range as A_11. The members of
// *this are not used.
Matrix Matrix::MergeMatrix(Matrix& A_11 , Matrix& A_12 , Matrix& A_21 , Matrix& A_22)
{
    const int blockRows = A_11.r;
    const int blockCols = A_11.c;
    Matrix merged(blockRows * 2, blockCols * 2);
    for (int i = 1; i <= blockRows; ++i)
    {
        int* top = merged[i];
        int* bottom = merged[i + blockRows];
        for (int j = 1; j <= blockCols; ++j)
        {
            top[j]                = A_11[i][j];
            top[j + blockCols]    = A_12[i][j];
            bottom[j]             = A_21[i][j];
            bottom[j + blockCols] = A_22[i][j];
        }
    }
    return merged;
}
// Pseudo-random fill value for cell (x, y).
// Returns (x * y * rand()) % 100 + x + y; for non-negative x and y the
// result lies in [x + y, x + y + 99].
//
// The generator is reseeded on every call with x + y + the current time, so
// the result is a function of x, y and the current second.
//
// Each factor is reduced modulo 100 before multiplying. The remainder of the
// product is unchanged, but the intermediate value stays below 100^3, so it
// can no longer overflow a signed int (undefined behaviour) for large x, y.
int rad(int x , int y)
{
    std::srand(static_cast<unsigned>(x + y + std::time(nullptr)));
    const int product = (x % 100) * (y % 100) * (std::rand() % 100);
    return product % 100 + x + y;
}
// Benchmark for one size: fills two x-by-x matrices via rad(), then prints
// the size followed by the CPU time of the ordinary product (Muti) and of the
// operator* product, in seconds.
// The second interval starts at the end of the first, so it also covers the
// printing of the first timing line.
inline void solve(int x)
{
    const int rows = x;
    const int cols = x;
    Matrix A(rows, cols) , B(rows, cols);
    for (int i = 1; i <= rows; ++i)
        for (int j = 1; j <= cols; ++j)
        {
            A[i][j] = rad(i, j);
            B[i][j] = rad(j, i);
        }

    std::cout << x << ":\n";

    const double plainStart = clock();
    Matrix ans = A.Muti(B);
    const double plainEnd = clock();
    std::cout << (plainEnd - plainStart)/CLOCKS_PER_SEC << "s\n";

    Matrix res = A * B ;
    const double fastEnd = clock();
    std::cout << (fastEnd - plainEnd)/CLOCKS_PER_SEC << "s\n";
    std::cout << '\n';
}
// Runs the benchmark for n = 2^8, 2^9, 2^10 and 2^11.
int main()
{
    int exponent = 8;
    while (exponent <= 11)
    {
        solve(1 << exponent);
        ++exponent;
    }
    return 0;
}
- 以下是本人测试出来的时间,在n逐渐变大的时候,越来越能体现Strassen乘法的速度优势
n | 256 | 512 | 1024 | 2048 |
---|---|---|---|---|
普通乘法 | 0.15s | 1.3s | 8.5s | 67.9s |
加速乘法 | 0.13s | 1.0s | 6.0s | 42.8s |