0
点赞
收藏
分享

微信扫一扫

hdu4965 巧用矩阵乘法结合律

题意:
     给两个矩阵,n*m的矩阵A,和m*n的矩阵B,
求(A*B)^(n*n)其中 m<=6,n<=1000。
思路:
      一开始直接模拟,写了个矩阵快速幂,超时了,因为A*B后得到的是1000*1000的矩阵,做乘法直接超时了,后来写了个这样的
    (A*B)^(n*n) = (A*B)*(A*B)*(A*B)...
                       = A * (B*A)*(B*A)*(B*A)...*B

矩阵虽然没有交换律但是有结合律,我们直接先B*A(得到的是一个最大6*6的矩阵)然后快速幂,然后再A * BA^(n*n-1) * B这样就行了,然后又超时了,算了很多次,感觉不可能超时,但还是超时了,原因就是我所有的矩阵用的都是mat[1002][1002]为了方便我都开结构体了,结果各种超时,最后没办法了,全都开数组,然后去模拟,A[1002][8],B[8][1002],BA[8][8]...,这样就AC了,难道开大的数组也会浪费很多时间?(这个地方头一次碰到)。



#include<stdio.h>
#include<string.h>

typedef struct
{
int mat[8][8];
}AA;

int A[1002][8] ,B[8][1002] ,C[1002][1002];
int nmm[1002][8];

AA mat_matba(int n ,int m)
{
AA c;
memset(c.mat ,0 ,sizeof(c.mat));
for(int k = 1 ;k <= n ;k ++)
for(int i = 1 ;i <= m ;i ++)
if(B[i][k])
for(int j = 1 ;j <= m ;j ++)
c.mat[i][j] = (c.mat[i][j] + B[i][k] * A[k][j])%6 ;
return c;
}

AA mat_mat(AA a ,AA b ,int n)
{
AA c;
memset(c.mat ,0 ,sizeof(c.mat));
for(int k = 1 ;k <= n ;k ++)
for(int i = 1 ;i <= n ;i ++)
if(a.mat[i][k])
for(int j = 1 ;j <= n ;j ++)
c.mat[i][j] = (c.mat[i][j] + a.mat[i][k] * b.mat[k][j]) % 6;
return c;
}


AA quick_mat(AA a ,int b ,int n)
{
AA c;
memset(c.mat ,0 ,sizeof(c.mat));
for(int i = 1 ;i <= n ;i ++)
c.mat[i][i] = 1;
while(b)
{
if(b&1) c = mat_mat(c ,a ,n);
a = mat_mat(a ,a ,n);
b >>= 1;
}
return c;
}

void mat_matnmm(AA mm ,int n ,int m)
{
memset(nmm ,0 ,sizeof(nmm));
for(int k = 1 ;k <= m ;k ++)
for(int i = 1 ;i <= n ;i ++)
if(A[i][k])
for(int j = 1 ;j <= m ;j ++)
nmm[i][j] = (nmm[i][j] + A[i][k] * mm.mat[k][j]) % 6;
}

void mat_matnmmn(int n ,int m)
{
memset(C ,0 ,sizeof(C));
for(int k = 1 ;k <= m ;k ++)
for(int i = 1 ;i <= n ;i ++)
for(int j = 1 ;j <= n ;j ++)
C[i][j] = (C[i][j] + nmm[i][k] * B[k][j]) % 6;
}



int main ()
{
int n ,m ,i ,j;
while(~scanf("%d %d" ,&n ,&m) && n + m)
{
for(i = 1 ;i <= n ;i ++)
for(j = 1 ;j <= m ;j ++)
scanf("%d" ,&A[i][j]);
for(i = 1 ;i <= m ;i ++)
for(j = 1 ;j <= n ;j ++)
scanf("%d" ,&B[i][j]);
AA c = mat_matba(n ,m);
AA ban = quick_mat(c ,n*n-1 ,m);
mat_matnmm(ban ,n ,m);
mat_matnmmn(n ,m);


int sum = 0;
for(i = 1 ;i <= n ;i ++)
for(j = 1 ;j <= n ;j ++)
sum += C[i][j];
printf("%d\n" ,sum);
}
return 0;
}




举报

相关推荐

0 条评论