#include "Matrix.hpp"

#include <algorithm>    // std::fill
#include <type_traits>  // std::is_same_v

#if defined(__AVX__) || defined(__SSE2__)
#include <immintrin.h>  // x86 SIMD intrinsics used by matrix_mult
#endif

// NOTE(review): the kernels below are compiled out. Before enabling them,
// confirm that Matrix.hpp provides, inside namespace spider, an integer type
// `isize` and a non-owning view `Matrix<T>` with members `T* data`,
// `isize rows`, `isize cols`, stored column-major (element (i, j) lives at
// data[i + j * rows]) — these are assumptions read off the indexing below.
#if 0
namespace spider {

/// @brief Writes `diag` on the main diagonal of `mat` and zero everywhere else.
///
/// Non-square matrices are allowed: only entries with i == j receive `diag`.
/// `mat` is taken by value and written through `mat.data`, so this relies on
/// Matrix<T> sharing storage with the caller's matrix (TODO confirm).
///
/// @param diag value stored at every (i, i) entry
/// @param mat  destination matrix view
template <typename T>
void matrix_fill(T diag, Matrix<T> mat) {
    // Column-outer / row-inner so consecutive writes are adjacent in memory.
    for (isize j = 0; j < mat.cols; j++) {
        for (isize i = 0; i < mat.rows; i++) {
            mat.data[i + j * mat.rows] = (i == j) ? diag : T(0);
        }
    }
}

/// @brief Computes the matrix product mr = m1 * m2.
///
/// Returns without touching `mr` when the shapes are incompatible
/// (m1 must be R x K, m2 must be K x C, mr must be R x C), or when `mr`
/// shares its data pointer with an input — `mr` is zero-filled before the
/// inputs are read, so such a call could only produce garbage. Partial
/// overlap between views is not detected; callers must not pass it.
///
/// For T = float / double the inner update uses AVX+FMA or SSE2 when the
/// compiler enables them; rows left over after the vector loop, and every
/// row for any other T, are handled by the scalar loop.
///
/// @param m1 left operand (R x K)
/// @param m2 right operand (K x C)
/// @param mr output (R x C), overwritten
template <typename T>
void matrix_mult(Matrix<T> m1, Matrix<T> m2, Matrix<T> mr) {
    // Natural constraints of matrix multiplication.
    if (m1.rows != mr.rows) return;
    if (m2.cols != mr.cols) return;
    if (m1.cols != m2.rows) return;
    // The output is cleared first, so it must not be one of the inputs.
    if (mr.data == m1.data || mr.data == m2.data) return;

    const isize rows = mr.rows;
    std::fill(mr.data, mr.data + rows * mr.cols, T(0));

    // j-n-i order: for each output column j, accumulate m2(n, j) * column n
    // of m1 into it; both columns are contiguous in column-major storage.
    for (isize j = 0; j < mr.cols; j++) {
        T* const out = mr.data + j * rows;  // column j of mr
        for (isize n = 0; n < m1.cols; n++) {
            const T* const col = m1.data + n * m1.rows;  // column n of m1
            const T scale = m2.data[n + j * m2.rows];    // m2(n, j)
            isize i = 0;
#if defined(__AVX__) && defined(__FMA__)
            // _mm256_fmadd_* requires FMA, which __AVX__ alone does not imply.
            if constexpr (std::is_same_v<T, float>) {
                const __m256 v_scale = _mm256_set1_ps(scale);
                for (; i + 8 <= rows; i += 8) {
                    const __m256 acc = _mm256_loadu_ps(out + i);
                    _mm256_storeu_ps(out + i, _mm256_fmadd_ps(_mm256_loadu_ps(col + i), v_scale, acc));
                }
            } else if constexpr (std::is_same_v<T, double>) {
                const __m256d v_scale = _mm256_set1_pd(scale);
                for (; i + 4 <= rows; i += 4) {
                    const __m256d acc = _mm256_loadu_pd(out + i);
                    _mm256_storeu_pd(out + i, _mm256_fmadd_pd(_mm256_loadu_pd(col + i), v_scale, acc));
                }
            }
#elif defined(__SSE2__)
            if constexpr (std::is_same_v<T, float>) {
                const __m128 v_scale = _mm_set1_ps(scale);
                for (; i + 4 <= rows; i += 4) {
                    const __m128 prod = _mm_mul_ps(_mm_loadu_ps(col + i), v_scale);
                    _mm_storeu_ps(out + i, _mm_add_ps(_mm_loadu_ps(out + i), prod));
                }
            } else if constexpr (std::is_same_v<T, double>) {
                const __m128d v_scale = _mm_set1_pd(scale);
                for (; i + 2 <= rows; i += 2) {
                    const __m128d prod = _mm_mul_pd(_mm_loadu_pd(col + i), v_scale);
                    _mm_storeu_pd(out + i, _mm_add_pd(_mm_loadu_pd(out + i), prod));
                }
            }
#endif
            // Scalar remainder: the tail rows a SIMD loop left behind, or all
            // rows when no SIMD path applies. Always runs, so no row is skipped.
            for (; i < rows; i++) {
                out[i] += col[i] * scale;
            }
        }
    }
}

}  // namespace spider
#endif  // kernels disabled