I'm still working on this
This commit is contained in:
103
src/spider/runtime/math/Matrix.cpp
Normal file
103
src/spider/runtime/math/Matrix.cpp
Normal file
@@ -0,0 +1,103 @@
|
||||
#include "Matrix.hpp"
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <type_traits>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
namespace spider {
|
||||
|
||||
template<typename T>
|
||||
void matrix_fill(T diag, Matrix<T> mat) {
|
||||
for (isize i = 0; i < mat.rows; i++) {
|
||||
for (isize j = 0; j < mat.cols; j++) {
|
||||
m.data[i + j * mat.rows] = i == j ? diag : T(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void matrix_mult(Matrix<T> m1, Matrix<T> m2, Matrix<T> mr) {
|
||||
// natural constrains of matrix multiplication
|
||||
if (m1.rows != mr.rows) return;
|
||||
if (m2.cols != mr.cols) return;
|
||||
if (m1.cols != m2.rows) return;
|
||||
|
||||
// fill result with zeroes
|
||||
std::fill(mr.data, mr.data + mr.rows * mr.cols, T(0));
|
||||
|
||||
// Begin Loop
|
||||
for (isize j = 0; j < mr.cols; j++) { // P
|
||||
for (isize n = 0; n < m1.cols; n++) { // N
|
||||
const T val_m2 = m2.data[n + j * m2.rows] * diag;
|
||||
isize i = 0;
|
||||
|
||||
#if defined(__AVX__)
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
const __m256 v_m2 = _mm256_set1_ps(val_m2);
|
||||
for (; i <= mr.rows - 8; i += 8) {
|
||||
__m256 v_m1 = _mm256_loadu_ps(&m1.data[i + n * m1.rows]);
|
||||
__m256 v_mr = _mm256_loadu_ps(&mr.data[i + j * mr.rows]);
|
||||
v_mr = _mm256_fmadd_ps(v_m1, v_m2, v_mr);
|
||||
_mm256_storeu_ps(&mr.data[i + j * mr.rows], v_mr);
|
||||
}
|
||||
if (i < mr.rows) {
|
||||
float buf_m1[8] = { 0 }, buf_mr[8] = { 0 };
|
||||
isize rem = mr.rows - i;
|
||||
std::memcpy(buf_m1, &m1.data[i + n * m1.rows], rem * sizeof(T));
|
||||
std::memcpy(buf_mr, &mr.data[i + j * mr.rows], rem * sizeof(T));
|
||||
_mm256_storeu_ps(buf_mr, _mm256_fmadd_ps(_mm256_loadu_ps(buf_m1), v_m2, _mm256_loadu_ps(buf_mr)));
|
||||
std::memcpy(&mr.data[i + j * mr.rows], buf_mr, rem * sizeof(T));
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, double>) {
|
||||
const __m256d v_m2 = _mm256_set1_pd(val_m2);
|
||||
for (; i <= mr.rows - 4; i += 4) {
|
||||
__m256d v_m1 = _mm256_loadu_pd(&m1.data[i + n * m1.rows]);
|
||||
__m256d v_mr = _mm256_loadu_pd(&mr.data[i + j * mr.rows]);
|
||||
v_mr = _mm256_fmadd_pd(v_m1, v_m2, v_mr);
|
||||
_mm256_storeu_pd(&mr.data[i + j * mr.rows], v_mr);
|
||||
}
|
||||
if (i < mr.rows) {
|
||||
double buf_m1[4] = { 0 }, buf_mr[4] = { 0 };
|
||||
isize rem = mr.rows - i;
|
||||
std::memcpy(buf_m1, &m1.data[i + n * m1.rows], rem * sizeof(T));
|
||||
std::memcpy(buf_mr, &mr.data[i + j * mr.rows], rem * sizeof(T));
|
||||
_mm256_storeu_pd(buf_mr, _mm256_fmadd_pd(_mm256_loadu_pd(buf_m1), v_m2, _mm256_loadu_pd(buf_mr)));
|
||||
std::memcpy(&mr.data[i + j * mr.rows], buf_mr, rem * sizeof(T));
|
||||
}
|
||||
}
|
||||
else
|
||||
#elif defined(__SSE2__)
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
const __m128 v_m2 = _mm_set1_ps(val_m2);
|
||||
for (; i <= mr.rows - 4; i += 4) {
|
||||
__m128 v_m1 = _mm_loadu_ps(&m1.data[i + n * m1.rows]);
|
||||
__m128 v_mr = _mm_loadu_ps(&mr.data[i + j * mr.rows]);
|
||||
v_mr = _mm_add_ps(v_mr, _mm_mul_ps(v_m1, v_m2));
|
||||
_mm_storeu_ps(&mr.data[i + j * mr.rows], v_mr);
|
||||
}
|
||||
// Tail buffer logic omitted for brevity, same as float AVX but with size 4
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, double>) {
|
||||
const __m128d v_m2 = _mm_set1_pd(val_m2);
|
||||
for (; i <= mr.rows - 2; i += 2) {
|
||||
__m128d v_m1 = _mm_loadu_pd(&m1.data[i + n * m1.rows]);
|
||||
__m128d v_mr = _mm_loadu_pd(&mr.data[i + j * mr.rows]);
|
||||
v_mr = _mm_add_pd(v_mr, _mm_mul_pd(v_m1, v_m2));
|
||||
_mm_storeu_pd(&mr.data[i + j * mr.rows], v_mr);
|
||||
}
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
// Fallback for non-SIMD or unsupported types
|
||||
for (; i < mr.rows; i++) {
|
||||
mr.data[i + j * mr.rows] += m1.data[i + n * m1.rows] * val_m2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user