Compare commits
4 Commits
c6c63d6391
...
2f38eb38b8
| Author | SHA1 | Date | |
|---|---|---|---|
| 2f38eb38b8 | |||
| 5f342991b0 | |||
| a5ffb69565 | |||
| 8697b44f53 |
484
src/spider/runtime/math/Matrix_Multiply.cpp
Normal file
484
src/spider/runtime/math/Matrix_Multiply.cpp
Normal file
@@ -0,0 +1,484 @@
|
||||
#include <iostream>
|
||||
|
||||
|
||||
template<typename T, int Rows, int Cols>
|
||||
struct Matrix {
|
||||
T data[Rows][Cols];
|
||||
|
||||
void setZero() {
|
||||
for (int i = 0; i < Rows; i++) {
|
||||
for (int j = 0; j < Cols; j++) {
|
||||
data[i][j] = T();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, int M, int N, int P>
|
||||
Matrix<T, M, P> mat_multiply(Matrix<T, M, N> A, Matrix<T, N, P> B) {
|
||||
//Determination of resulting matrix size
|
||||
// COMPILE-TIME calculation (using constexpr)
|
||||
constexpr int RSize = (M > P) ? M : P;
|
||||
|
||||
// COMPILE-TIME square size calculation
|
||||
constexpr int CSize = (RSize <= 4) ? 4 :
|
||||
(RSize % 4 == 0) ? RSize :
|
||||
4 * (RSize / 4 + 1);
|
||||
|
||||
// SIMD width selection
|
||||
int simd_width;
|
||||
if (RSize > 8) {
|
||||
simd_width = 8;
|
||||
}
|
||||
else if (RSize > 4 && RSize < 8){
|
||||
simd_width = 4;
|
||||
}
|
||||
else {
|
||||
simd_width = RSize;
|
||||
}
|
||||
|
||||
|
||||
// Create result square matrix with size CSize(optimization)
|
||||
Matrix<T, CSize, CSize> tempResult;
|
||||
tempResult.setZero();
|
||||
|
||||
// Matrix multiplication with SIMD-style unrolling
|
||||
for (int i = 0; i < M; i++) { // For each row in A
|
||||
for (int k = 0; k < N; k++) { // For each inner dimension
|
||||
T a_val = A.data[i][k];
|
||||
|
||||
// Process columns in chunks of simd_width
|
||||
int j = 0;
|
||||
while (j < P) {
|
||||
int remaining = P - j;
|
||||
|
||||
if (remaining >= simd_width) {
|
||||
// Full SIMD operation - unrolled loops
|
||||
if (simd_width == 8) {
|
||||
tempResult.data[i][j] += a_val * B.data[k][j];
|
||||
tempResult.data[i][j+1] += a_val * B.data[k][j+1];
|
||||
tempResult.data[i][j+2] += a_val * B.data[k][j+2];
|
||||
tempResult.data[i][j+3] += a_val * B.data[k][j+3];
|
||||
tempResult.data[i][j+4] += a_val * B.data[k][j+4];
|
||||
tempResult.data[i][j+5] += a_val * B.data[k][j+5];
|
||||
tempResult.data[i][j+6] += a_val * B.data[k][j+6];
|
||||
tempResult.data[i][j+7] += a_val * B.data[k][j+7];
|
||||
}
|
||||
else if (simd_width == 4) {
|
||||
tempResult.data[i][j] += a_val * B.data[k][j];
|
||||
tempResult.data[i][j+1] += a_val * B.data[k][j+1];
|
||||
tempResult.data[i][j+2] += a_val * B.data[k][j+2];
|
||||
tempResult.data[i][j+3] += a_val * B.data[k][j+3];
|
||||
}
|
||||
else if (simd_width == 3) {
|
||||
tempResult.data[i][j] += a_val * B.data[k][j];
|
||||
tempResult.data[i][j+1] += a_val * B.data[k][j+1];
|
||||
tempResult.data[i][j+2] += a_val * B.data[k][j+2];
|
||||
}
|
||||
else if (simd_width == 2) {
|
||||
tempResult.data[i][j] += a_val * B.data[k][j];
|
||||
tempResult.data[i][j+1] += a_val * B.data[k][j+1];
|
||||
}
|
||||
else { // simd_width == 1
|
||||
tempResult.data[i][j] += a_val * B.data[k][j];
|
||||
}
|
||||
|
||||
j += simd_width;
|
||||
}
|
||||
else {
|
||||
// Handle remaining columns that don't fit in SIMD width
|
||||
for (int s = 0; s < remaining; s++) {
|
||||
tempResult.data[i][j + s] += a_val * B.data[k][j + s];
|
||||
}
|
||||
j += remaining;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the actual result (M x P) from the temporary square matrix
|
||||
Matrix<T, M, P> result;
|
||||
for (int i = 0; i < M; i++) {
|
||||
for (int j = 0; j < P; j++) {
|
||||
result.data[i][j] = tempResult.data[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Helper function to print matrices for testing
|
||||
void printMatrix(const Matrix<double, 3, 3>& m) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
std::cout << m.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void printMatrix(const Matrix<double, 2, 2>& m) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
std::cout << m.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void printMatrix(const Matrix<double, 3, 2>& m) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
std::cout << m.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void printMatrix(const Matrix<double, 1, 1>& m) {
|
||||
std::cout << m.data[0][0] << std::endl;
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::cout << "========================================" << std::endl;
|
||||
std::cout << " MATRIX MULTIPLICATION TEST SUITE" << std::endl;
|
||||
std::cout << "========================================" << std::endl;
|
||||
|
||||
int tests_passed = 0;
|
||||
int total_tests = 6;
|
||||
|
||||
// ==================== TEST 1: Simple 2x2 ====================
|
||||
std::cout << "\n========================================" << std::endl;
|
||||
std::cout << "TEST 1: 2x2 Matrix Multiplication" << std::endl;
|
||||
std::cout << "========================================" << std::endl;
|
||||
{
|
||||
Matrix<double, 2, 2> A;
|
||||
A.data[0][0] = 1; A.data[0][1] = 2;
|
||||
A.data[1][0] = 3; A.data[1][1] = 4;
|
||||
|
||||
Matrix<double, 2, 2> B;
|
||||
B.data[0][0] = 5; B.data[0][1] = 6;
|
||||
B.data[1][0] = 7; B.data[1][1] = 8;
|
||||
|
||||
Matrix<double, 2, 2> expected;
|
||||
expected.data[0][0] = 19; expected.data[0][1] = 22;
|
||||
expected.data[1][0] = 43; expected.data[1][1] = 50;
|
||||
|
||||
auto result = mat_multiply(A, B);
|
||||
|
||||
std::cout << "A:" << std::endl;
|
||||
printMatrix(A);
|
||||
std::cout << "B:" << std::endl;
|
||||
printMatrix(B);
|
||||
std::cout << "Expected:" << std::endl;
|
||||
printMatrix(expected);
|
||||
std::cout << "Result:" << std::endl;
|
||||
printMatrix(result);
|
||||
|
||||
bool passed = true;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
if (std::abs(result.data[i][j] - expected.data[i][j]) > 1e-6) {
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (passed) {
|
||||
std::cout << "✓ PASSED" << std::endl;
|
||||
tests_passed++;
|
||||
} else {
|
||||
std::cout << "✗ FAILED" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== TEST 2: Non-Square Matrices ====================
|
||||
std::cout << "\n========================================" << std::endl;
|
||||
std::cout << "TEST 2: Non-Square Matrix Multiplication (3x4 * 4x2)" << std::endl;
|
||||
std::cout << "========================================" << std::endl;
|
||||
{
|
||||
Matrix<double, 3, 4> A;
|
||||
int val = 1;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 4; j++) {
|
||||
A.data[i][j] = val++;
|
||||
}
|
||||
}
|
||||
|
||||
Matrix<double, 4, 2> B;
|
||||
val = 1;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
B.data[i][j] = val++;
|
||||
}
|
||||
}
|
||||
|
||||
Matrix<double, 3, 2> expected;
|
||||
expected.data[0][0] = 50; expected.data[0][1] = 60;
|
||||
expected.data[1][0] = 114; expected.data[1][1] = 140;
|
||||
expected.data[2][0] = 178; expected.data[2][1] = 220;
|
||||
|
||||
auto result = mat_multiply(A, B);
|
||||
|
||||
std::cout << "A (3x4):" << std::endl;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 4; j++) {
|
||||
std::cout << A.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "B (4x2):" << std::endl;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
std::cout << B.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "Expected (3x2):" << std::endl;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
std::cout << expected.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "Result:" << std::endl;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
std::cout << result.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
bool passed = true;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
if (std::abs(result.data[i][j] - expected.data[i][j]) > 1e-6) {
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (passed) {
|
||||
std::cout << "✓ PASSED" << std::endl;
|
||||
tests_passed++;
|
||||
} else {
|
||||
std::cout << "✗ FAILED" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== TEST 3: Identity Matrix ====================
|
||||
std::cout << "\n========================================" << std::endl;
|
||||
std::cout << "TEST 3: Identity Matrix Multiplication" << std::endl;
|
||||
std::cout << "========================================" << std::endl;
|
||||
{
|
||||
Matrix<double, 3, 3> A;
|
||||
int val = 1;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
A.data[i][j] = val++;
|
||||
}
|
||||
}
|
||||
|
||||
Matrix<double, 3, 3> I;
|
||||
// Initialize identity matrix
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
I.data[i][j] = (i == j) ? 1.0 : 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
auto result = mat_multiply(A, I);
|
||||
|
||||
std::cout << "A * I should equal A" << std::endl;
|
||||
std::cout << "A:" << std::endl;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
std::cout << A.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "Result:" << std::endl;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
std::cout << result.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
bool passed = true;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
if (std::abs(result.data[i][j] - A.data[i][j]) > 1e-6) {
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (passed) {
|
||||
std::cout << "✓ PASSED" << std::endl;
|
||||
tests_passed++;
|
||||
} else {
|
||||
std::cout << "✗ FAILED" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== TEST 4: Negative Numbers ====================
|
||||
std::cout << "\n========================================" << std::endl;
|
||||
std::cout << "TEST 4: Matrix with Negative Numbers" << std::endl;
|
||||
std::cout << "========================================" << std::endl;
|
||||
{
|
||||
Matrix<double, 2, 2> A;
|
||||
A.data[0][0] = 1; A.data[0][1] = -2;
|
||||
A.data[1][0] = -3; A.data[1][1] = 4;
|
||||
|
||||
Matrix<double, 2, 2> B;
|
||||
B.data[0][0] = -5; B.data[0][1] = 6;
|
||||
B.data[1][0] = 7; B.data[1][1] = -8;
|
||||
|
||||
Matrix<double, 2, 2> expected;
|
||||
expected.data[0][0] = -19; expected.data[0][1] = 22;
|
||||
expected.data[1][0] = 43; expected.data[1][1] = -50;
|
||||
|
||||
auto result = mat_multiply(A, B);
|
||||
|
||||
std::cout << "A:" << std::endl;
|
||||
printMatrix(A);
|
||||
std::cout << "B:" << std::endl;
|
||||
printMatrix(B);
|
||||
std::cout << "Expected:" << std::endl;
|
||||
printMatrix(expected);
|
||||
std::cout << "Result:" << std::endl;
|
||||
printMatrix(result);
|
||||
|
||||
bool passed = true;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
if (std::abs(result.data[i][j] - expected.data[i][j]) > 1e-6) {
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (passed) {
|
||||
std::cout << "✓ PASSED" << std::endl;
|
||||
tests_passed++;
|
||||
} else {
|
||||
std::cout << "✗ FAILED" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== TEST 5: Vector Dot Product ====================
|
||||
std::cout << "\n========================================" << std::endl;
|
||||
std::cout << "TEST 5: Vector Dot Product (1x3 * 3x1)" << std::endl;
|
||||
std::cout << "========================================" << std::endl;
|
||||
{
|
||||
Matrix<double, 1, 3> rowVector;
|
||||
rowVector.data[0][0] = 1;
|
||||
rowVector.data[0][1] = 2;
|
||||
rowVector.data[0][2] = 3;
|
||||
|
||||
Matrix<double, 3, 1> colVector;
|
||||
colVector.data[0][0] = 4;
|
||||
colVector.data[1][0] = 5;
|
||||
colVector.data[2][0] = 6;
|
||||
|
||||
auto result = mat_multiply(rowVector, colVector);
|
||||
|
||||
std::cout << "Row vector (1x3): ";
|
||||
std::cout << rowVector.data[0][0] << " " << rowVector.data[0][1] << " " << rowVector.data[0][2] << std::endl;
|
||||
std::cout << "Column vector (3x1):" << std::endl;
|
||||
std::cout << colVector.data[0][0] << std::endl;
|
||||
std::cout << colVector.data[1][0] << std::endl;
|
||||
std::cout << colVector.data[2][0] << std::endl;
|
||||
std::cout << "Result (should be 32): " << result.data[0][0] << std::endl;
|
||||
|
||||
bool passed = std::abs(result.data[0][0] - 32) < 1e-6;
|
||||
|
||||
if (passed) {
|
||||
std::cout << "✓ PASSED" << std::endl;
|
||||
tests_passed++;
|
||||
} else {
|
||||
std::cout << "✗ FAILED" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== TEST 6: Associativity ====================
|
||||
std::cout << "\n========================================" << std::endl;
|
||||
std::cout << "TEST 6: Associativity (A*B)*C == A*(B*C)" << std::endl;
|
||||
std::cout << "========================================" << std::endl;
|
||||
{
|
||||
Matrix<double, 2, 3> A;
|
||||
int val = 1;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < 3; j++) {
|
||||
A.data[i][j] = val++;
|
||||
}
|
||||
}
|
||||
|
||||
Matrix<double, 3, 4> B;
|
||||
val = 1;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int j = 0; j < 4; j++) {
|
||||
B.data[i][j] = val++;
|
||||
}
|
||||
}
|
||||
|
||||
Matrix<double, 4, 2> C;
|
||||
val = 1;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
C.data[i][j] = val++;
|
||||
}
|
||||
}
|
||||
|
||||
auto AB = mat_multiply(A, B);
|
||||
auto AB_C = mat_multiply(AB, C);
|
||||
|
||||
auto BC = mat_multiply(B, C);
|
||||
auto A_BC = mat_multiply(A, BC);
|
||||
|
||||
std::cout << "Result of (A*B)*C:" << std::endl;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
std::cout << AB_C.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "Result of A*(B*C):" << std::endl;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
std::cout << A_BC.data[i][j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
bool passed = true;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
if (std::abs(AB_C.data[i][j] - A_BC.data[i][j]) > 1e-6) {
|
||||
passed = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (passed) {
|
||||
std::cout << "✓ PASSED (A*B)*C equals A*(B*C)" << std::endl;
|
||||
tests_passed++;
|
||||
} else {
|
||||
std::cout << "✗ FAILED" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== FINAL RESULTS ====================
|
||||
std::cout << "\n========================================" << std::endl;
|
||||
std::cout << "RESULTS: " << tests_passed << " / " << total_tests << " tests passed" << std::endl;
|
||||
std::cout << "========================================" << std::endl;
|
||||
|
||||
if (tests_passed == total_tests) {
|
||||
std::cout << "🎉 All tests passed! Your matrix multiplication is correct! 🎉" << std::endl;
|
||||
} else {
|
||||
std::cout << "⚠️ " << (total_tests - tests_passed) << " test(s) failed. Please check your implementation." << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user