0
点赞
收藏
分享

微信扫一扫

【算法设计与分析】【C++】斯特拉森Strassen矩阵乘法加速

8052cf60ff5c 2022-03-14 阅读 63

【算法设计与分析】【C++】斯特拉森Strassen矩阵乘法加速

矩阵乘法加速原理

在这里插入图片描述

  • 简单地来说,就是把普通矩阵的八次乘法优化成七次乘法,我们可以直接按照上面给出的式子
    来写。
  • 但是需要注意的是,进行矩阵分块和矩阵合并也需要消耗计算资源,显然在矩阵非常小时,乘法消耗的时间远小于各种分块后再计算的时间,因此程序中的BORDER_SIZE的选择对于程序的运行影响极大,经过本人测试下来,128是比较的好的数据
  • 接下来上代码(本代码仅限于n*n的矩阵并且n为2的幂次方
/*
author : TheSun
Time : 2022-03-04 00:44:17
*/
#include<bits/stdc++.h>
using namespace std;
const int BORDER_SIZE = 128; //在这个数量下面使用普通乘法,因为此时各种赋值所消耗的时间大于乘法运算的时间
struct Matrix
{
    int r, c;
    int **v;
    Matrix(int x=0 , int y=0);
    Matrix(const Matrix& others);
    ~Matrix();
    int* operator[](int n);
    Matrix operator *(Matrix others);//重载为加速矩阵乘法
    inline void setSize(int x , int y);
    inline void operator=(Matrix others);
    void Print();
    inline Matrix operator+(Matrix& others);
    inline Matrix operator-(Matrix& others);
    inline Matrix Muti(Matrix& others); //普通矩阵乘法
    inline void MatrixDiv(Matrix& A_11 , Matrix& A_12 , Matrix& A_21 , Matrix& A_22); // 矩阵分块
    inline Matrix MergeMatrix(Matrix& A_11 , Matrix& A_12 , Matrix& A_21 , Matrix& A_22); //合并矩阵
    //bool oprator==()
};

Matrix::Matrix(const Matrix& others)
{   
    //if()
    r = others.r , c = others.c;
    v = new int *[r + 1];
    for(int i=0 ; i<=r ; ++i)
        v[i] = new int [c + 1];
    for(int i=1 ; i<=r ; ++i)
        for(int j= 1 ; j<=c ; ++j)
            v[i][j] = 0;
}


Matrix Matrix::operator*(Matrix others)
{
    if( r <= BORDER_SIZE && c <= BORDER_SIZE && others.c <= BORDER_SIZE)
    {
        return this->Muti(others);
    }
    else
    {
        Matrix A_11(r/2 , c/2) , A_12(r/2 , c/2) , A_21(r/2 , c/2) , A_22(r/2 , c/2),
               B_11(others.r/2 , others.c/2) , B_12(others.r/2 , others.c/2) , B_21(others.r/2 , others.c/2) , 
               B_22(others.r/2 , others.c/2) ,
               C_11(r/2 , others.c/2) , C_12(r/2 , others.c/2) , C_21(r/2 , others.c/2) , C_22(r/2 , others.c/2);
        this->MatrixDiv(A_11 , A_12 , A_21 , A_22);
        others.MatrixDiv(B_11 , B_12 , B_21 , B_22);

        Matrix  M_1 = A_11 * (B_12 - B_22) , 
                M_2 = (A_11 + A_12) * B_22 ,
                M_3 = (A_21 + A_22) * B_11 ,
                M_4 = A_22 * (B_21 - B_11) , 
                M_5 = (A_11 + A_22) * (B_11 + B_22),
                M_6 = (A_12 - A_22) * (B_21 + B_22),
                M_7 = (A_11 - A_12) * (B_11 + B_12);

        C_11 = M_5 + M_4 - M_2 + M_6 ;
        C_12 = M_1 + M_2;
        C_21 = M_3 + M_4;
        C_22 = M_5 + M_1 - M_3 - M_7;

        Matrix res = this->MergeMatrix(C_11 , C_12 , C_21 , C_22);

        return res;
    }   
}

inline Matrix Matrix::operator+(Matrix& others)
{
    Matrix res(r , others.c);
    for(int i=1 ; i<=r ; ++i)
    {
        for(int j=1 ; j<=c ; ++j)
        {
            res[i][j] = v[i][j] + others[i][j];
        }
    }
    return res;
}

Matrix Matrix::operator-(Matrix& others)
{
    Matrix res(r , others.c);
    for(int i=1 ; i<=r ; ++i)
    {
        for(int j=1 ; j<=c ; ++j)
        {
            res[i][j] = v[i][j] - others[i][j];
        }
    }
    return res;
}

Matrix::Matrix(int x , int y)
{
    r = x , c = y;
    v = new int *[r + 1];
    for(int i=0 ; i<=r ; ++i)
        v[i] = new int [c + 1];
    for(int i=1 ; i<=r ; ++i)
        for(int j= 1 ; j<=c ; ++j)
            v[i][j] = 0;
}

Matrix::~Matrix()
{
    for(int i=0 ; i<=r ; ++i)
    {
        if(v[i] != nullptr ) delete[] v[i];
    }
    if(v != nullptr)
        delete[] v;
}

void Matrix::setSize(int x  , int y)
{
    for(int i=0 ; i<=r ; ++i)
    {
        if(v[i] != nullptr ) delete[] v[i];
    }
    if(v != nullptr)
        delete[] v;

    r = x , c = y;
    v = new int *[r + 1];
    for(int i=0 ; i<=r ; ++i)
        v[i] = new int [c + 1];

}

Matrix Matrix::Muti(Matrix& others)
{
    Matrix res(r, others.c);
    for (int i = 1; i <= r; ++i)
    {
        for (int j = 1; j <= c; ++j)
        {
            if(v[i][j])
            {
                for (int k = 1; k <= others.c; ++k)
                {
                    res[i][k] = res[i][k] + v[i][j] * others[j][k];
                }
            }
        }
    }
    return res;
}

int* Matrix::operator[](int n){return v[n];}

void Matrix::operator=(Matrix others)
{
    r  = others.r , c = others.c;
    this->setSize(r , c);
    for(int i=1 ; i<=r ; ++i)
        for(int j=1 ; j<=c ; ++j)
            v[i][j] = others[i][j];
}

void Matrix::Print()
{
    for (int i = 1; i <= r; ++i)
    {
        for (int j = 1; j <= c; ++j)
        {
            cout << setw(6) << v[i][j] ;
        }
        cout << "\n";
    }
    cout << '\n';
}



void Matrix::MatrixDiv(Matrix& A_11 , Matrix& A_12 , Matrix& A_21 , Matrix& A_22)
{   
    //debug(r , c);
    //debug(A_22.r , A_22.c , A_22.v.size() , A_22.v[0].size());
    for(int i=1 ; i<=r/2 ; ++i)
        for(int j=1 ; j<=c/2 ; ++j)
            A_11[i][j] = v[i][j];

    for(int i=1 ; i<=r/2 ; ++i)
        for(int j=c/2+1 ; j<=c ; ++j)
            A_12[i][j-c/2] = v[i][j];
    
    for(int i=r/2 + 1 ; i<=r ; ++i)
        for(int j=1 ; j<=c/2 ; ++j)
            A_21[i-r/2][j] = v[i][j];
        
    for(int i=r/2+1 ; i<=r ; ++i)
        for(int j=c/2+1 ; j<=c ; ++j)
            A_22[i-r/2][j-c/2] = v[i][j];
        
}

Matrix Matrix::MergeMatrix(Matrix& A_11 , Matrix& A_12 , Matrix& A_21 , Matrix& A_22)
{   
    int ri = A_11.r * 2 , ci =  A_11.c * 2;
    Matrix res( ri, ci);
    for(int i=1 ; i<=ri/2 ; ++i)
        for(int j=1 ; j<=ci/2 ; ++j)
            res[i][j] = A_11[i][j];

    for(int i=1 ; i<=ri/2 ; ++i)
        for(int j=ci/2+1 ; j<=ci ; ++j)
            res[i][j] = A_12[i][j-ci/2];

    for(int i=ri/2+1 ; i<=ri ; ++i)
        for(int j=1 ; j<=ci/2 ; ++j)
            res[i][j] = A_21[i-ri/2][j];

    for(int i=ri/2+1 ; i<=ri ; ++i)
        for(int j=ci/2+1 ; j<=ci ; ++j)
            res[i][j] = A_22[i-ri/2][j-ci/2];
    return res;
}

int rad(int x , int y)
{
    srand(x + y + time(NULL));
    return (x * y * rand()) % 100 + x + y;
}


inline void solve(int x)
{
    int r , c;
    r = c = x;
    Matrix A(r, c) , B(r,c);
    for(int i=1 ; i<=r ; ++i)
    {
        for(int j=1 ; j<=c ; ++j)
            {
                A[i][j] = rad(i,j);
                B[i][j] = rad(j,i);
            }
    }
    cout << x << ":\n";
    double t1 = clock();
    Matrix ans = A.Muti(B);
    double t2 = clock();
    cout << (t2 - t1)/CLOCKS_PER_SEC << "s\n"; 
    Matrix res = A * B ;
    double t3 = clock();
    cout << (t3 - t2)/CLOCKS_PER_SEC << "s\n";
    cout << '\n';
}

int main()
{   
    for(int i=8 ; i<=11 ; ++i)
    {
        solve((1<<i));
    }
}
  • 以下是本人测试出来的时间,在n逐渐变大的时候,越来越能体现Strassen乘法的速度优势
n25651210242048
普通乘法0.15s1.3s8.5s67.9s
加速乘法0.13s1.0s6.0s42.8s
举报

相关推荐

0 条评论