/*
 * Single-precision matrix product C += A * B^T.
 *
 * Storage convention (from the index arithmetic below): every matrix is
 * column-major. A and B are m_a x n_a, with element (r, c) at [r + c * m_a];
 * C is m_a x m_a and is accumulated into, never overwritten. The 36x36 case
 * uses a hand-unrolled SSE kernel; every other size is multiplied in 4x4
 * tiles over zero-padded copies of the operands.
 *
 * SIMD code is compiled only when the target enables it (__SSE__ for the
 * 36x36 kernel, __SSE3__ for the horizontal-add tile kernel, e.g. via -msse3
 * or -march=native). Otherwise portable scalar code that groups the additions
 * the same way is used, so the file builds on any target.
 */
#include <stdlib.h>

#if defined(__SSE3__)
#include <pmmintrin.h> /* _mm_hadd_ps; also pulls in the SSE/SSE2 headers */
#elif defined(__SSE__)
#include <xmmintrin.h>
#endif

/* Tile edge used for padding. accumulateTile4x4 is hard-wired to 4x4 tiles,
 * so this must stay 4. */
const int BLOCK_SIZE = 4;

void sgemm( int m_a, int n_a, float *A, float *B, float *C );
void matrixMultiply( int m_a, int n_a, float *A, float *B, float *C );
float *convertToRowMajor( float *matrix, int oldRows, int oldCols, const int newRows, const int newCols );
float *convertToColMajor( float *matrix, int oldRows, int oldCols, const int newRows, const int newCols );
void compactAndConvertToColMajor( float *input, float *output, int originalRows, int currentCols );
void optimize36x36( int m_a, int n_a, float *A, float *B, float *C );

/* Rounds n (>= 0) up to the next multiple of BLOCK_SIZE. */
static int roundUpToBlock( int n )
{
    const int rem = n % BLOCK_SIZE;
    return rem ? n + (BLOCK_SIZE - rem) : n;
}

/*
 * Reference triple loop: C += A * B^T on the column-major operands, using no
 * scratch memory. For each C element the k terms are added in increasing k.
 */
static void sgemmNaive( int m, int n, const float *A, const float *B, float *C )
{
    for (int j = 0; j < m; ++j) {
        float *cCol = C + (size_t)j * m;
        for (int k = 0; k < n; ++k) {
            const float *aCol = A + (size_t)k * m;
            const float b = B[j + (size_t)k * m];
            for (int i = 0; i < m; ++i) {
                cCol[i] += aCol[i] * b;
            }
        }
    }
}

/*
 * Accumulates one 4x4 tile:
 *   c[r * ldc + s] += dot(a[r * ld .. r * ld + 3], b[s * ld .. s * ld + 3])
 * for r, s in 0..3. Rows of a and b are ld floats apart; rows of c are ldc
 * floats apart.
 */
static void accumulateTile4x4( const float *a, const float *b, float *c, size_t ld, size_t ldc )
{
#if defined(__SSE3__)
    const __m128 b0 = _mm_loadu_ps(b);
    const __m128 b1 = _mm_loadu_ps(b + ld);
    const __m128 b2 = _mm_loadu_ps(b + 2 * ld);
    const __m128 b3 = _mm_loadu_ps(b + 3 * ld);

    for (int r = 0; r < 4; ++r) {
        const __m128 ar = _mm_loadu_ps(a + r * ld);
        /* hadd tree: lane s of the final vector is
         * (p0 + p1) + (p2 + p3) for the products p of row r with b_s. */
        const __m128 lo = _mm_hadd_ps(_mm_mul_ps(ar, b0), _mm_mul_ps(ar, b1));
        const __m128 hi = _mm_hadd_ps(_mm_mul_ps(ar, b2), _mm_mul_ps(ar, b3));
        float *cr = c + r * ldc;
        _mm_storeu_ps(cr, _mm_add_ps(_mm_loadu_ps(cr), _mm_hadd_ps(lo, hi)));
    }
#else
    for (int r = 0; r < 4; ++r) {
        const float *ar = a + r * ld;
        for (int s = 0; s < 4; ++s) {
            const float *bs = b + s * ld;
            /* Grouped like the SSE3 hadd reduction: (p0 + p1) + (p2 + p3). */
            c[r * ldc + s] += (ar[0] * bs[0] + ar[1] * bs[1])
                            + (ar[2] * bs[2] + ar[3] * bs[3]);
        }
    }
#endif
}

#if defined(__SSE__)
/*
 * C(:, j) += sum over t < 8 of A(:, k0 + t) * B(j, k0 + t), for 36x36
 * column-major operands. cCol points at C(0, j). Each B scalar is broadcast
 * once and reused down the whole column.
 */
static void accumulatePanel8x36( const float *A, const float *B, float *cCol, int j, int k0 )
{
    const float *a = A + k0 * 36;
    const float *b = B + j + k0 * 36;
    const __m128 b0 = _mm_load1_ps(b);
    const __m128 b1 = _mm_load1_ps(b + 36);
    const __m128 b2 = _mm_load1_ps(b + 72);
    const __m128 b3 = _mm_load1_ps(b + 108);
    const __m128 b4 = _mm_load1_ps(b + 144);
    const __m128 b5 = _mm_load1_ps(b + 180);
    const __m128 b6 = _mm_load1_ps(b + 216);
    const __m128 b7 = _mm_load1_ps(b + 252);

    for (int i = 0; i < 36; i += 4) {
        __m128 c = _mm_loadu_ps(cCol + i);
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i), b0));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 36), b1));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 72), b2));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 108), b3));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 144), b4));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 180), b5));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 216), b6));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 252), b7));
        _mm_storeu_ps(cCol + i, c);
    }
}

/* Same as accumulatePanel8x36 but for the 4 columns k0 .. k0 + 3. */
static void accumulatePanel4x36( const float *A, const float *B, float *cCol, int j, int k0 )
{
    const float *a = A + k0 * 36;
    const float *b = B + j + k0 * 36;
    const __m128 b0 = _mm_load1_ps(b);
    const __m128 b1 = _mm_load1_ps(b + 36);
    const __m128 b2 = _mm_load1_ps(b + 72);
    const __m128 b3 = _mm_load1_ps(b + 108);

    for (int i = 0; i < 36; i += 4) {
        __m128 c = _mm_loadu_ps(cCol + i);
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i), b0));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 36), b1));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 72), b2));
        c = _mm_add_ps(c, _mm_mul_ps(_mm_loadu_ps(a + i + 108), b3));
        _mm_storeu_ps(cCol + i, c);
    }
}
#endif

/*
 * C += A * B^T. A and B are m_a x n_a and C is m_a x m_a, all column-major.
 * 36x36 inputs use the specialised kernel; everything else goes through the
 * padded 4x4-tile path.
 */
void sgemm( int m_a, int n_a, float *A, float *B, float *C )
{
    if (m_a == 36 && n_a == 36) {
        optimize36x36(m_a, n_a, A, B, C);
    } else {
        matrixMultiply(m_a, n_a, A, B, C);
    }
}

/*
 * General-size C += A * B^T using 4x4 tiles.
 *
 * A and the rows of B are copied into zero-padded buffers whose rows are
 * nDim floats apart, so C(i, j) = sum_k rowA[i*nDim + k] * colB[j*nDim + k].
 * The padding contributes zeros. C is copied into a padded row-major
 * accumulator, tiles are added to it, and the m_a x m_a result is copied
 * back. If any scratch allocation fails, the product is computed in place
 * with the reference loop instead.
 */
void matrixMultiply( int m_a, int n_a, float *A, float *B, float *C )
{
    if (m_a <= 0 || n_a <= 0) {
        return; /* empty product: C += 0 */
    }

    const int mDim = roundUpToBlock(m_a);
    const int nDim = roundUpToBlock(n_a);

    float *rowA = convertToRowMajor(A, m_a, n_a, mDim, nDim);
    /* Swapped dimensions: B's row j ends up contiguous at colB + j * nDim. */
    float *colB = convertToColMajor(B, n_a, m_a, nDim, mDim);
    /* Seed with the current C so the result is C + A * B^T, matching the
       accumulating 36x36 path. */
    float *rowC = convertToRowMajor(C, m_a, m_a, mDim, mDim);

    if (rowA == NULL || colB == NULL || rowC == NULL) {
        sgemmNaive(m_a, n_a, A, B, C);
    } else {
        const size_t ldab = (size_t)nDim;
        const size_t ldc = (size_t)mDim;
        for (int kb = 0; kb < nDim; kb += BLOCK_SIZE) {
            for (int ib = 0; ib < mDim; ib += BLOCK_SIZE) {
                for (int jb = 0; jb < mDim; jb += BLOCK_SIZE) {
                    accumulateTile4x4(rowA + ib * ldab + kb,
                                      colB + jb * ldab + kb,
                                      rowC + ib * ldc + jb,
                                      ldab, ldc);
                }
            }
        }
        compactAndConvertToColMajor(rowC, C, m_a, mDim);
    }

    free(rowA);
    free(colB);
    free(rowC);
}

/*
 * Returns a newly calloc'd newRows x newCols row-major copy of the
 * column-major oldRows x oldCols `matrix`. Cells outside the original extent
 * are zero. Returns NULL if allocation fails; the caller frees the result.
 */
float *convertToRowMajor( float *matrix, int oldRows, int oldCols, const int newRows, const int newCols )
{
    float *rowMajor = calloc((size_t)newRows * (size_t)newCols, sizeof *rowMajor);
    if (rowMajor == NULL) {
        return NULL;
    }
    for (int i = 0; i < oldRows; ++i) {
        for (int j = 0; j < oldCols; ++j) {
            rowMajor[(size_t)i * newCols + j] = matrix[i + (size_t)j * oldRows];
        }
    }
    return rowMajor;
}

/*
 * Returns a newly calloc'd newRows x newCols column-major copy of the
 * row-major oldRows x oldCols `matrix`. Cells outside the original extent
 * are zero. Returns NULL if allocation fails; the caller frees the result.
 */
float *convertToColMajor( float *matrix, int oldRows, int oldCols, const int newRows, const int newCols )
{
    float *colMajor = calloc((size_t)newRows * (size_t)newCols, sizeof *colMajor);
    if (colMajor == NULL) {
        return NULL;
    }
    for (int i = 0; i < oldRows; ++i) {
        for (int j = 0; j < oldCols; ++j) {
            colMajor[i + (size_t)j * newRows] = matrix[(size_t)i * oldCols + j];
        }
    }
    return colMajor;
}

/*
 * Copies the leading originalRows x originalRows square of the row-major
 * `input` (rows currentCols floats apart) into `output` as a dense
 * column-major originalRows x originalRows matrix, overwriting it.
 */
void compactAndConvertToColMajor( float *input, float *output, int originalRows, int currentCols )
{
    for (int row = 0; row < originalRows; ++row) {
        const float *src = input + (size_t)row * currentCols;
        for (int col = 0; col < originalRows; ++col) {
            output[row + (size_t)col * originalRows] = src[col];
        }
    }
}

/*
 * C += A * B^T for 36x36 column-major operands. Based on J******** M********'s
 * code. For each column j of C, the k range is processed as four 8-wide
 * panels and one 4-wide panel (k = 32..35).
 */
void optimize36x36( int m_a, int n_a, float *A, float *B, float *C )
{
    (void)m_a; /* dimensions are fixed at 36; sgemm only calls this for 36x36 */
    (void)n_a;
#if defined(__SSE__)
    for (int j = 0; j < 36; ++j) {
        float *cCol = C + j * 36;
        for (int k = 0; k < 32; k += 8) {
            accumulatePanel8x36(A, B, cCol, j, k);
        }
        accumulatePanel4x36(A, B, cCol, j, 32);
    }
#else
    sgemmNaive(36, 36, A, B, C);
#endif
}