// Fixed-size dense matrix type and a generic matrix product.
//
// NOTE(review): the header named by the original `#include` directive and the
// template parameter lists were not recoverable from the source as received.
// The parameter lists below (<T, Rows, Cols> and <T, M, N, P>) are
// reconstructed from how the names are used in the bodies -- confirm the
// order against existing call sites that pass explicit template arguments.
// Nothing in this file needs a library header, so none is included.

/// Dense Rows x Cols matrix stored in a plain two-dimensional array.
///
/// The type is kept as an aggregate (no constructors) so it can be
/// brace-initialized and copied trivially when T allows it.
///
/// @tparam T     element type
/// @tparam Rows  number of rows, must be positive
/// @tparam Cols  number of columns, must be positive
template <typename T, int Rows, int Cols>
struct Matrix {
    static_assert(Rows > 0 && Cols > 0, "Matrix dimensions must be positive");

    T data[Rows][Cols];

    /// Assigns a value-initialized T (zero for arithmetic types) to every element.
    void setZero() {
        for (int i = 0; i < Rows; i++) {
            for (int j = 0; j < Cols; j++) {
                data[i][j] = T();
            }
        }
    }
};

/// Computes the matrix product A * B.
///
/// @tparam T  element type; must support T(), operator* and operator+=
/// @tparam M  rows of A and of the result
/// @tparam N  columns of A and rows of B (the inner dimension)
/// @tparam P  columns of B and of the result
/// @param A   left operand, M x N
/// @param B   right operand, N x P
/// @return    the M x P product
template <typename T, int M, int N, int P>
Matrix<T, M, P> mat_multiply(const Matrix<T, M, N>& A, const Matrix<T, N, P>& B) {
    Matrix<T, M, P> result;
    result.setZero();

    // The previous version picked a runtime "SIMD width" and hand-unrolled the
    // column loop for widths 8/4/3/2/1 plus a remainder loop. Every one of
    // those branches performed the same update on consecutive columns, so a
    // single column loop yields the same values; it also accumulates straight
    // into the result instead of going through a temporary that was copied out.
    //
    // Loop order is i-k-j: each element result[i][j] still receives its terms
    // in ascending k, so floating-point results match the unrolled version,
    // and the innermost loop walks both result and B along a row.
    for (int i = 0; i < M; i++) {
        for (int k = 0; k < N; k++) {
            const T a_val = A.data[i][k];
            for (int j = 0; j < P; j++) {
                result.data[i][j] += a_val * B.data[k][j];
            }
        }
    }

    return result;
}