106 lines
4.6 KiB
C++
106 lines
4.6 KiB
C++
#include "Matrix.hpp"
|
|
|
|
#include <immintrin.h>
|
|
#include <type_traits>
|
|
#include <algorithm>
|
|
#include <cstring>
|
|
|
|
/*
|
|
namespace spider {
|
|
|
|
template<typename T>
|
|
void matrix_fill(T diag, Matrix<T> mat) {
|
|
for (isize i = 0; i < mat.rows; i++) {
|
|
for (isize j = 0; j < mat.cols; j++) {
|
|
m.data[i + j * mat.rows] = i == j ? diag : T(0);
|
|
}
|
|
}
|
|
}
|
|
|
|
template<typename T>
|
|
void matrix_mult(Matrix<T> m1, Matrix<T> m2, Matrix<T> mr) {
|
|
// natural constrains of matrix multiplication
|
|
if (m1.rows != mr.rows) return;
|
|
if (m2.cols != mr.cols) return;
|
|
if (m1.cols != m2.rows) return;
|
|
|
|
// fill result with zeroes
|
|
std::fill(mr.data, mr.data + mr.rows * mr.cols, T(0));
|
|
|
|
// Begin Loop
|
|
for (isize j = 0; j < mr.cols; j++) { // P
|
|
for (isize n = 0; n < m1.cols; n++) { // N
|
|
const T val_m2 = m2.data[n + j * m2.rows] * diag;
|
|
isize i = 0;
|
|
|
|
#if defined(__AVX__)
|
|
if constexpr (std::is_same_v<T, float>) {
|
|
const __m256 v_m2 = _mm256_set1_ps(val_m2);
|
|
for (; i <= mr.rows - 8; i += 8) {
|
|
__m256 v_m1 = _mm256_loadu_ps(&m1.data[i + n * m1.rows]);
|
|
__m256 v_mr = _mm256_loadu_ps(&mr.data[i + j * mr.rows]);
|
|
v_mr = _mm256_fmadd_ps(v_m1, v_m2, v_mr);
|
|
_mm256_storeu_ps(&mr.data[i + j * mr.rows], v_mr);
|
|
}
|
|
if (i < mr.rows) {
|
|
float buf_m1[8] = { 0 }, buf_mr[8] = { 0 };
|
|
isize rem = mr.rows - i;
|
|
std::memcpy(buf_m1, &m1.data[i + n * m1.rows], rem * sizeof(T));
|
|
std::memcpy(buf_mr, &mr.data[i + j * mr.rows], rem * sizeof(T));
|
|
_mm256_storeu_ps(buf_mr, _mm256_fmadd_ps(_mm256_loadu_ps(buf_m1), v_m2, _mm256_loadu_ps(buf_mr)));
|
|
std::memcpy(&mr.data[i + j * mr.rows], buf_mr, rem * sizeof(T));
|
|
}
|
|
}
|
|
else if constexpr (std::is_same_v<T, double>) {
|
|
const __m256d v_m2 = _mm256_set1_pd(val_m2);
|
|
for (; i <= mr.rows - 4; i += 4) {
|
|
__m256d v_m1 = _mm256_loadu_pd(&m1.data[i + n * m1.rows]);
|
|
__m256d v_mr = _mm256_loadu_pd(&mr.data[i + j * mr.rows]);
|
|
v_mr = _mm256_fmadd_pd(v_m1, v_m2, v_mr);
|
|
_mm256_storeu_pd(&mr.data[i + j * mr.rows], v_mr);
|
|
}
|
|
if (i < mr.rows) {
|
|
double buf_m1[4] = { 0 }, buf_mr[4] = { 0 };
|
|
isize rem = mr.rows - i;
|
|
std::memcpy(buf_m1, &m1.data[i + n * m1.rows], rem * sizeof(T));
|
|
std::memcpy(buf_mr, &mr.data[i + j * mr.rows], rem * sizeof(T));
|
|
_mm256_storeu_pd(buf_mr, _mm256_fmadd_pd(_mm256_loadu_pd(buf_m1), v_m2, _mm256_loadu_pd(buf_mr)));
|
|
std::memcpy(&mr.data[i + j * mr.rows], buf_mr, rem * sizeof(T));
|
|
}
|
|
}
|
|
else
|
|
#elif defined(__SSE2__)
|
|
if constexpr (std::is_same_v<T, float>) {
|
|
const __m128 v_m2 = _mm_set1_ps(val_m2);
|
|
for (; i <= mr.rows - 4; i += 4) {
|
|
__m128 v_m1 = _mm_loadu_ps(&m1.data[i + n * m1.rows]);
|
|
__m128 v_mr = _mm_loadu_ps(&mr.data[i + j * mr.rows]);
|
|
v_mr = _mm_add_ps(v_mr, _mm_mul_ps(v_m1, v_m2));
|
|
_mm_storeu_ps(&mr.data[i + j * mr.rows], v_mr);
|
|
}
|
|
// Tail buffer logic omitted for brevity, same as float AVX but with size 4
|
|
}
|
|
else if constexpr (std::is_same_v<T, double>) {
|
|
const __m128d v_m2 = _mm_set1_pd(val_m2);
|
|
for (; i <= mr.rows - 2; i += 2) {
|
|
__m128d v_m1 = _mm_loadu_pd(&m1.data[i + n * m1.rows]);
|
|
__m128d v_mr = _mm_loadu_pd(&mr.data[i + j * mr.rows]);
|
|
v_mr = _mm_add_pd(v_mr, _mm_mul_pd(v_m1, v_m2));
|
|
_mm_storeu_pd(&mr.data[i + j * mr.rows], v_mr);
|
|
}
|
|
}
|
|
else
|
|
#endif
|
|
{
|
|
// Fallback for non-SIMD or unsupported types
|
|
for (; i < mr.rows; i++) {
|
|
mr.data[i + j * mr.rows] += m1.data[i + n * m1.rows] * val_m2;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
*/
|