Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
hrydgard
GitHub Repository: hrydgard/ppsspp
Path: blob/master/Common/Math/fast/fast_matrix.c
3187 views
1
#include "ppsspp_config.h"
2
3
#include "Common/Math/SIMDHeaders.h"
4
5
#include "fast_matrix.h"
6
7
#if PPSSPP_ARCH(SSE2)
8
9
// SSE 4x4 float matrix product:
//   dest[4*c + r] = sum_{k=0..3} a[4*k + r] * b[4*c + k]   for c, r in 0..3
// i.e. each group of four consecutive floats in a is scaled by one entry of the
// matching group of b and summed. All loads/stores are unaligned (_mm_loadu_ps /
// _mm_storeu_ps), so no alignment is required of any pointer.
// All of a is loaded up front, and each group of b is read before the matching
// group of dest is written, so dest may be the same pointer as a or as b.
void fast_matrix_mul_4x4_sse(float *dest, const float *a, const float *b) {
	__m128 acols[4];
	for (int k = 0; k < 4; k++)
		acols[k] = _mm_loadu_ps(a + 4 * k);

	for (int c = 0; c < 4; c++) {
		const float *bcol = b + 4 * c;
		// Start from the first product, then add the rest in order k = 1..3.
		__m128 acc = _mm_mul_ps(acols[0], _mm_set1_ps(bcol[0]));
		for (int k = 1; k < 4; k++)
			acc = _mm_add_ps(acc, _mm_mul_ps(acols[k], _mm_set1_ps(bcol[k])));
		_mm_storeu_ps(dest + 4 * c, acc);
	}
}
24
25
#elif PPSSPP_ARCH(LOONGARCH64_LSX)
26
27
// Bit-level view of a 32-bit float as an int32_t (union type punning is
// well-defined in C).
typedef union
{
	int32_t i;
	float f;
} FloatInt;

// Broadcasts val into all four float lanes of an LSX vector: the float's bit
// pattern is reinterpreted as int32_t and replicated with __lsx_vreplgr2vr_w.
// Deliberately NOT named __lsx_vreplfr2vr_s (as it used to be): identifiers
// beginning with a double underscore are reserved for the implementation, and
// toolchain LSX headers may provide an intrinsic with that exact name.
// NOTE(review): confirm against the GCC/Clang versions targeted.
static __m128 lsx_splat_f32(float val)
{
	FloatInt tmpval = {.f = val};
	return (__m128)__lsx_vreplgr2vr_w(tmpval.i);
}

// LoongArch LSX 4x4 float matrix product:
//   dest[4*c + r] = sum_{k=0..3} a[4*k + r] * b[4*c + k]   for c, r in 0..3
// The first term is a plain multiply; the remaining three are accumulated
// with __lsx_vfmadd_s.
// All of a is loaded up front, and each group of four b entries is read before
// the matching group of dest is stored, so dest may be the same pointer as a or b.
void fast_matrix_mul_4x4_lsx(float *dest, const float *a, const float *b) {
	__m128 a_col_1 = (__m128)__lsx_vld(a, 0);
	__m128 a_col_2 = (__m128)__lsx_vld(a + 4, 0);
	__m128 a_col_3 = (__m128)__lsx_vld(a + 8, 0);
	__m128 a_col_4 = (__m128)__lsx_vld(a + 12, 0);

	for (int i = 0; i < 16; i += 4) {
		__m128 b1 = lsx_splat_f32(b[i]);
		__m128 b2 = lsx_splat_f32(b[i + 1]);
		__m128 b3 = lsx_splat_f32(b[i + 2]);
		__m128 b4 = lsx_splat_f32(b[i + 3]);

		__m128 result = __lsx_vfmul_s(a_col_1, b1);
		result = __lsx_vfmadd_s(a_col_2, b2, result);
		result = __lsx_vfmadd_s(a_col_3, b3, result);
		result = __lsx_vfmadd_s(a_col_4, b4, result);

		__lsx_vst(result, &dest[i], 0);
	}
}
60
61
#elif PPSSPP_ARCH(ARM_NEON)
62
63
// Adapted from https://developer.arm.com/documentation/102467/0100/Matrix-multiplication-example
//
// NEON 4x4 float matrix product, one column of C per iteration:
//   C[4*c + r] = sum_{k=0..3} A[4*k + r] * B[4*c + k]   for c, r in 0..3
// The first term uses vmulq_laneq_f32; the remaining three are accumulated with
// vfmaq_laneq_f32 (fused multiply-add), selecting lane k of the B column.
// All of A is loaded up front, and each B column is loaded before the matching
// C column is stored, so C may be the same pointer as A or B.
// NOTE(review): the *_laneq_f32 intrinsics are AArch64-only in arm_neon.h;
// presumably the project's SIMD headers cover 32-bit NEON builds — confirm.
void fast_matrix_mul_4x4_neon(float *C, const float *A, const float *B) {
	const float32x4_t acol0 = vld1q_f32(A);
	const float32x4_t acol1 = vld1q_f32(A + 4);
	const float32x4_t acol2 = vld1q_f32(A + 8);
	const float32x4_t acol3 = vld1q_f32(A + 12);

	for (int off = 0; off < 16; off += 4) {
		const float32x4_t bcol = vld1q_f32(B + off);
		float32x4_t ccol = vmulq_laneq_f32(acol0, bcol, 0);
		ccol = vfmaq_laneq_f32(ccol, acol1, bcol, 1);
		ccol = vfmaq_laneq_f32(ccol, acol2, bcol, 2);
		ccol = vfmaq_laneq_f32(ccol, acol3, bcol, 3);
		vst1q_f32(C + off, ccol);
	}
}
117
118
#else
119
120
// Portable 4x4 float matrix product (used when no SIMD path is compiled in):
//   dest[4*c + r] = b[4*c+0]*a[r] + b[4*c+1]*a[4+r] + b[4*c+2]*a[8+r] + b[4*c+3]*a[12+r]
// for c, r in 0..3. The terms are summed left to right, in the same order as
// the previous hand-unrolled version.
//
// The result is built in a local buffer and copied out at the end, so dest may
// be the same pointer as a or b (in-place multiply). This matches the SSE, LSX
// and NEON variants, which read their inputs before writing each output column.
// Writing dest directly would corrupt in-place results, because later entries
// read a/b values that were already overwritten.
void fast_matrix_mul_4x4_c(float *dest, const float *a, const float *b) {
	float result[16];

	for (int col = 0; col < 4; col++) {
		const float *bcol = b + 4 * col;
		for (int row = 0; row < 4; row++) {
			result[4 * col + row] = bcol[0] * a[row] + bcol[1] * a[4 + row] +
			                        bcol[2] * a[8 + row] + bcol[3] * a[12 + row];
		}
	}

	for (int i = 0; i < 16; i++)
		dest[i] = result[i];
}
158
159
#endif
160
161