code chnages

This commit is contained in:
2026-03-27 16:34:23 -06:00
parent 8697b44f53
commit a5ffb69565
2 changed files with 383 additions and 21 deletions

View File

@@ -17,27 +17,15 @@ struct Matrix {
template<typename T, int M, int N, int P> template<typename T, int M, int N, int P>
Matrix<T, M, P> mat_multiply(Matrix<T, M, N> A, Matrix<T, N, P> B) { Matrix<T, M, P> mat_multiply(Matrix<T, M, N> A, Matrix<T, N, P> B) {
//Determination of resulting matrix size //Determination of resulting matrix size
int RSize; // COMPILE-TIME calculation (using constexpr)
if(M > P){ constexpr int RSize = (M > P) ? M : P;
RSize = M ;
}
else{
RSize = P;
}
//Condition for resulting matrix size rounding calculation // COMPILE-TIME square size calculation
int CSize; constexpr int CSize = (RSize <= 4) ? 4 :
if (RSize <= 4){ (RSize % 4 == 0) ? RSize :
CSize = 4; 4 * (RSize / 4 + 1);
}
else if (RSize % 4 == 0){
CSize = RSize;
}
else{
CSize = 4 * (int(RSize/4) + 1);
}
// SIMD width selection (based on your Python logic) // SIMD width selection
int simd_width; int simd_width;
if (RSize > 8) { if (RSize > 8) {
simd_width = 8; simd_width = 8;
@@ -117,6 +105,380 @@ Matrix<T, M, P> mat_multiply(Matrix<T, M, N> A, Matrix<T, N, P> B) {
} }
return result; return result;
}
// Helper function to print matrices for testing
void printMatrix(const Matrix<double, 3, 3>& m) {
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
std::cout << m.data[i][j] << " ";
}
std::cout << std::endl;
}
}
void printMatrix(const Matrix<double, 2, 2>& m) {
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
std::cout << m.data[i][j] << " ";
}
std::cout << std::endl;
}
}
void printMatrix(const Matrix<double, 3, 2>& m) {
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 2; j++) {
std::cout << m.data[i][j] << " ";
}
std::cout << std::endl;
}
}
void printMatrix(const Matrix<double, 1, 1>& m) {
std::cout << m.data[0][0] << std::endl;
}
int main() {
std::cout << "========================================" << std::endl;
std::cout << " MATRIX MULTIPLICATION TEST SUITE" << std::endl;
std::cout << "========================================" << std::endl;
int tests_passed = 0;
int total_tests = 6;
// ==================== TEST 1: Simple 2x2 ====================
std::cout << "\n========================================" << std::endl;
std::cout << "TEST 1: 2x2 Matrix Multiplication" << std::endl;
std::cout << "========================================" << std::endl;
{
Matrix<double, 2, 2> A;
A.data[0][0] = 1; A.data[0][1] = 2;
A.data[1][0] = 3; A.data[1][1] = 4;
Matrix<double, 2, 2> B;
B.data[0][0] = 5; B.data[0][1] = 6;
B.data[1][0] = 7; B.data[1][1] = 8;
Matrix<double, 2, 2> expected;
expected.data[0][0] = 19; expected.data[0][1] = 22;
expected.data[1][0] = 43; expected.data[1][1] = 50;
auto result = mat_multiply(A, B);
std::cout << "A:" << std::endl;
printMatrix(A);
std::cout << "B:" << std::endl;
printMatrix(B);
std::cout << "Expected:" << std::endl;
printMatrix(expected);
std::cout << "Result:" << std::endl;
printMatrix(result);
bool passed = true;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
if (std::abs(result.data[i][j] - expected.data[i][j]) > 1e-6) {
passed = false;
}
}
}
if (passed) {
std::cout << "✓ PASSED" << std::endl;
tests_passed++;
} else {
std::cout << "✗ FAILED" << std::endl;
}
}
// ==================== TEST 2: Non-Square Matrices ====================
std::cout << "\n========================================" << std::endl;
std::cout << "TEST 2: Non-Square Matrix Multiplication (3x4 * 4x2)" << std::endl;
std::cout << "========================================" << std::endl;
{
Matrix<double, 3, 4> A;
int val = 1;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 4; j++) {
A.data[i][j] = val++;
}
}
Matrix<double, 4, 2> B;
val = 1;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 2; j++) {
B.data[i][j] = val++;
}
}
Matrix<double, 3, 2> expected;
expected.data[0][0] = 50; expected.data[0][1] = 60;
expected.data[1][0] = 114; expected.data[1][1] = 140;
expected.data[2][0] = 178; expected.data[2][1] = 220;
auto result = mat_multiply(A, B);
std::cout << "A (3x4):" << std::endl;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 4; j++) {
std::cout << A.data[i][j] << " ";
}
std::cout << std::endl;
}
std::cout << "B (4x2):" << std::endl;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 2; j++) {
std::cout << B.data[i][j] << " ";
}
std::cout << std::endl;
}
std::cout << "Expected (3x2):" << std::endl;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 2; j++) {
std::cout << expected.data[i][j] << " ";
}
std::cout << std::endl;
}
std::cout << "Result:" << std::endl;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 2; j++) {
std::cout << result.data[i][j] << " ";
}
std::cout << std::endl;
}
bool passed = true;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 2; j++) {
if (std::abs(result.data[i][j] - expected.data[i][j]) > 1e-6) {
passed = false;
}
}
}
if (passed) {
std::cout << "✓ PASSED" << std::endl;
tests_passed++;
} else {
std::cout << "✗ FAILED" << std::endl;
}
}
// ==================== TEST 3: Identity Matrix ====================
std::cout << "\n========================================" << std::endl;
std::cout << "TEST 3: Identity Matrix Multiplication" << std::endl;
std::cout << "========================================" << std::endl;
{
Matrix<double, 3, 3> A;
int val = 1;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
A.data[i][j] = val++;
}
}
Matrix<double, 3, 3> I;
// Initialize identity matrix
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
I.data[i][j] = (i == j) ? 1.0 : 0.0;
}
}
auto result = mat_multiply(A, I);
std::cout << "A * I should equal A" << std::endl;
std::cout << "A:" << std::endl;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
std::cout << A.data[i][j] << " ";
}
std::cout << std::endl;
}
std::cout << "Result:" << std::endl;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
std::cout << result.data[i][j] << " ";
}
std::cout << std::endl;
}
bool passed = true;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
if (std::abs(result.data[i][j] - A.data[i][j]) > 1e-6) {
passed = false;
}
}
}
if (passed) {
std::cout << "✓ PASSED" << std::endl;
tests_passed++;
} else {
std::cout << "✗ FAILED" << std::endl;
}
}
// ==================== TEST 4: Negative Numbers ====================
std::cout << "\n========================================" << std::endl;
std::cout << "TEST 4: Matrix with Negative Numbers" << std::endl;
std::cout << "========================================" << std::endl;
{
Matrix<double, 2, 2> A;
A.data[0][0] = 1; A.data[0][1] = -2;
A.data[1][0] = -3; A.data[1][1] = 4;
Matrix<double, 2, 2> B;
B.data[0][0] = -5; B.data[0][1] = 6;
B.data[1][0] = 7; B.data[1][1] = -8;
Matrix<double, 2, 2> expected;
expected.data[0][0] = -19; expected.data[0][1] = 22;
expected.data[1][0] = 43; expected.data[1][1] = -50;
auto result = mat_multiply(A, B);
std::cout << "A:" << std::endl;
printMatrix(A);
std::cout << "B:" << std::endl;
printMatrix(B);
std::cout << "Expected:" << std::endl;
printMatrix(expected);
std::cout << "Result:" << std::endl;
printMatrix(result);
bool passed = true;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
if (std::abs(result.data[i][j] - expected.data[i][j]) > 1e-6) {
passed = false;
}
}
}
if (passed) {
std::cout << "✓ PASSED" << std::endl;
tests_passed++;
} else {
std::cout << "✗ FAILED" << std::endl;
}
}
// ==================== TEST 5: Vector Dot Product ====================
std::cout << "\n========================================" << std::endl;
std::cout << "TEST 5: Vector Dot Product (1x3 * 3x1)" << std::endl;
std::cout << "========================================" << std::endl;
{
Matrix<double, 1, 3> rowVector;
rowVector.data[0][0] = 1;
rowVector.data[0][1] = 2;
rowVector.data[0][2] = 3;
Matrix<double, 3, 1> colVector;
colVector.data[0][0] = 4;
colVector.data[1][0] = 5;
colVector.data[2][0] = 6;
auto result = mat_multiply(rowVector, colVector);
std::cout << "Row vector (1x3): ";
std::cout << rowVector.data[0][0] << " " << rowVector.data[0][1] << " " << rowVector.data[0][2] << std::endl;
std::cout << "Column vector (3x1):" << std::endl;
std::cout << colVector.data[0][0] << std::endl;
std::cout << colVector.data[1][0] << std::endl;
std::cout << colVector.data[2][0] << std::endl;
std::cout << "Result (should be 32): " << result.data[0][0] << std::endl;
bool passed = std::abs(result.data[0][0] - 32) < 1e-6;
if (passed) {
std::cout << "✓ PASSED" << std::endl;
tests_passed++;
} else {
std::cout << "✗ FAILED" << std::endl;
}
}
// ==================== TEST 6: Associativity ====================
std::cout << "\n========================================" << std::endl;
std::cout << "TEST 6: Associativity (A*B)*C == A*(B*C)" << std::endl;
std::cout << "========================================" << std::endl;
{
Matrix<double, 2, 3> A;
int val = 1;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 3; j++) {
A.data[i][j] = val++;
}
}
Matrix<double, 3, 4> B;
val = 1;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 4; j++) {
B.data[i][j] = val++;
}
}
Matrix<double, 4, 2> C;
val = 1;
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 2; j++) {
C.data[i][j] = val++;
}
}
auto AB = mat_multiply(A, B);
auto AB_C = mat_multiply(AB, C);
auto BC = mat_multiply(B, C);
auto A_BC = mat_multiply(A, BC);
std::cout << "Result of (A*B)*C:" << std::endl;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
std::cout << AB_C.data[i][j] << " ";
}
std::cout << std::endl;
}
std::cout << "Result of A*(B*C):" << std::endl;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
std::cout << A_BC.data[i][j] << " ";
}
std::cout << std::endl;
}
bool passed = true;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
if (std::abs(AB_C.data[i][j] - A_BC.data[i][j]) > 1e-6) {
passed = false;
}
}
}
if (passed) {
std::cout << "✓ PASSED (A*B)*C equals A*(B*C)" << std::endl;
tests_passed++;
} else {
std::cout << "✗ FAILED" << std::endl;
}
}
// ==================== FINAL RESULTS ====================
std::cout << "\n========================================" << std::endl;
std::cout << "RESULTS: " << tests_passed << " / " << total_tests << " tests passed" << std::endl;
std::cout << "========================================" << std::endl;
if (tests_passed == total_tests) {
std::cout << "🎉 All tests passed! Your matrix multiplication is correct! 🎉" << std::endl;
} else {
std::cout << "⚠️ " << (total_tests - tests_passed) << " test(s) failed. Please check your implementation." << std::endl;
}
return 0;
} }

Binary file not shown.