Without further ado, the code is pasted directly below.
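In brief: mat_mul computes C = A * B for row-major matrices (A is M x K, B is K x N, C is M x N). The #if 0 branch is a plain-C reference that blocks the M and K loops by 4; the #else branch is the NEON version, which broadcasts a 4x4 tile of A into float32x4_t registers and accumulates four rows of C at a time with vmlaq_f32. The vectorized inner loops step n by 4, so N is assumed to be a multiple of 4.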
#include <arm_neon.h>

#ifndef DATA_TYPE
#define DATA_TYPE float /* assumed here: the NEON branch hard-codes float32 intrinsics */
#endif

/* C = A * B, where A is M x K, B is K x N and C is M x N, all row-major. */
/* static inline so the snippet links cleanly when compiled as a single C file. */
static inline void mat_mul(DATA_TYPE *A, DATA_TYPE *B, DATA_TYPE *C, int M, int N, int K)
{
#if 0 /* plain-C reference implementation, disabled; the NEON version is in the #else branch */
    int mStep = 4;
    int MalignedStep = (M / mStep) * mStep;
    int m = 0;
    for (m = 0; m < MalignedStep; m += mStep)
    {
        for (int p = 0; p < N; p++)
        {
            for (int q = 0; q < mStep; q++)
            {
                C[(m + q) * N + p] = 0;
            }
        }
        int kStep = 4;
        int KalignedStep = (K / kStep) * kStep;
        int k = 0;
        for (k = 0; k < KalignedStep; k += kStep)
        {
            for (int n = 0; n < N; n += 1)
            {
                for (int q = 0; q < mStep; q++)
                {
                    for (int p = 0; p < kStep; p++)
                    {
                        C[(m + q) * N + n] += A[(m + q) * K + k + p] * B[(k + p) * N + n];
                    }
                }
            }
        }
        for (; k < K; k += 1)
        {
            for (int n = 0; n < N; n += 1)
            {
                for (int q = 0; q < mStep; q++)
                {
                    C[(m + q) * N + n] += A[(m + q) * K + k] * B[k * N + n];
                }
            }
        }
    }
    for (; m < M; m += 1)
    {
        for (int p = 0; p < N; p++)
        {
            C[m * N + p] = 0;
        }
        int kStep = 4;
        int KalignedStep = (K / kStep) * kStep;
        int k = 0;
        for (k = 0; k < KalignedStep; k += kStep)
        {
            DATA_TYPE A_data[kStep];
            for (int p = 0; p < kStep; p++)
            {
                A_data[p] = A[m * K + k + p];
            }
            for (int n = 0; n < N; n += 1)
            {
                for (int p = 0; p < kStep; p++)
                {
                    C[m * N + n] += A_data[p] * B[(k + p) * N + n];
                }
            }
        }
        for (; k < K; k += 1)
        {
            DATA_TYPE A_data;
            A_data = A[m * K + k];
            for (int n = 0; n < N; n += 1)
            {
                C[m * N + n] += A_data * B[k * N + n];
            }
        }
    }
#else /* NEON float32 implementation of the same 4x4 blocking scheme */
    int mStep = 4;
    int MalignedStep = (M / mStep) * mStep;
    int m = 0;
    /* Main loop: produce 4 rows of C per iteration.
       Note: the vectorized n-loops below step by 4, so N is assumed to be a multiple of 4. */
    for (m = 0; m < MalignedStep; m += mStep)
    {
        /* Zero-initialize the 4 output rows of this block. */
        for (int p = 0; p < N; p++)
        {
            C[m * N + p] = 0;
            C[(m + 1) * N + p] = 0;
            C[(m + 2) * N + p] = 0;
            C[(m + 3) * N + p] = 0;
        }
        int kStep = 4;
        int KalignedStep = (K / kStep) * kStep;
        int k = 0;
        for (k = 0; k < KalignedStep; k += kStep)
        {
            /* Broadcast each element of the 4x4 tile of A (rows m..m+3, columns k..k+3) into a NEON register. */
            float32x4_t A_data1 = vdupq_n_f32(A[m * K + k]);
            float32x4_t A_data2 = vdupq_n_f32(A[m * K + k + 1]);
            float32x4_t A_data3 = vdupq_n_f32(A[m * K + k + 2]);
            float32x4_t A_data4 = vdupq_n_f32(A[m * K + k + 3]);
            float32x4_t A_data21 = vdupq_n_f32(A[(m + 1) * K + k]);
            float32x4_t A_data22 = vdupq_n_f32(A[(m + 1) * K + k + 1]);
            float32x4_t A_data23 = vdupq_n_f32(A[(m + 1) * K + k + 2]);
            float32x4_t A_data24 = vdupq_n_f32(A[(m + 1) * K + k + 3]);
            float32x4_t A_data31 = vdupq_n_f32(A[(m + 2) * K + k]);
            float32x4_t A_data32 = vdupq_n_f32(A[(m + 2) * K + k + 1]);
            float32x4_t A_data33 = vdupq_n_f32(A[(m + 2) * K + k + 2]);
            float32x4_t A_data34 = vdupq_n_f32(A[(m + 2) * K + k + 3]);
            float32x4_t A_data41 = vdupq_n_f32(A[(m + 3) * K + k]);
            float32x4_t A_data42 = vdupq_n_f32(A[(m + 3) * K + k + 1]);
            float32x4_t A_data43 = vdupq_n_f32(A[(m + 3) * K + k + 2]);
            float32x4_t A_data44 = vdupq_n_f32(A[(m + 3) * K + k + 3]);
            float *C_ptr1 = C + m * N;
            float *C_ptr2 = C + (m + 1) * N;
            float *C_ptr3 = C + (m + 2) * N;
            float *C_ptr4 = C + (m + 3) * N;
            for (int n = 0; n < N; n += 4)
            {
                /* Load 4 consecutive elements from each of the 4 rows of B (rows k..k+3). */
                float32x4_t B_data1 = vld1q_f32(B + (k + 0) * N + n);
                float32x4_t B_data2 = vld1q_f32(B + (k + 1) * N + n);
                float32x4_t B_data3 = vld1q_f32(B + (k + 2) * N + n);
                float32x4_t B_data4 = vld1q_f32(B + (k + 3) * N + n);
                float32x4_t C_data = vld1q_f32(C_ptr1 + n);
                float32x4_t C_data2 = vld1q_f32(C_ptr2 + n);
                float32x4_t C_data3 = vld1q_f32(C_ptr3 + n);
                float32x4_t C_data4 = vld1q_f32(C_ptr4 + n);
                /* Multiply-accumulate: each row of C receives the contribution of the 4 k values. */
                C_data = vmlaq_f32(C_data, A_data1, B_data1);
                C_data = vmlaq_f32(C_data, A_data2, B_data2);
                C_data = vmlaq_f32(C_data, A_data3, B_data3);
                C_data = vmlaq_f32(C_data, A_data4, B_data4);
                C_data2 = vmlaq_f32(C_data2, A_data21, B_data1);
                C_data2 = vmlaq_f32(C_data2, A_data22, B_data2);
                C_data2 = vmlaq_f32(C_data2, A_data23, B_data3);
                C_data2 = vmlaq_f32(C_data2, A_data24, B_data4);
                C_data3 = vmlaq_f32(C_data3, A_data31, B_data1);
                C_data3 = vmlaq_f32(C_data3, A_data32, B_data2);
                C_data3 = vmlaq_f32(C_data3, A_data33, B_data3);
                C_data3 = vmlaq_f32(C_data3, A_data34, B_data4);
                C_data4 = vmlaq_f32(C_data4, A_data41, B_data1);
                C_data4 = vmlaq_f32(C_data4, A_data42, B_data2);
                C_data4 = vmlaq_f32(C_data4, A_data43, B_data3);
                C_data4 = vmlaq_f32(C_data4, A_data44, B_data4);
                vst1q_f32(C_ptr1 + n, C_data);
                vst1q_f32(C_ptr2 + n, C_data2);
                vst1q_f32(C_ptr3 + n, C_data3);
                vst1q_f32(C_ptr4 + n, C_data4);
            }
        }
        /* Scalar tail for the remaining K % kStep values of k. */
        for (; k < K; k += 1)
        {
            DATA_TYPE A_data = A[m * K + k];
            DATA_TYPE A_data2 = A[(m + 1) * K + k];
            DATA_TYPE A_data3 = A[(m + 2) * K + k];
            DATA_TYPE A_data4 = A[(m + 3) * K + k];
            for (int n = 0; n < N; n += 1)
            {
                C[m * N + n] += A_data * B[k * N + n];
                C[(m + 1) * N + n] += A_data2 * B[k * N + n];
                C[(m + 2) * N + n] += A_data3 * B[k * N + n];
                C[(m + 3) * N + n] += A_data4 * B[k * N + n];
            }
        }
    }
    /* Tail for the remaining M % mStep rows: one row of C at a time,
       still vectorized over n (again assuming N is a multiple of 4). */
    for (; m < M; m += 1)
    {
        for (int p = 0; p < N; p++)
        {
            C[m * N + p] = 0;
        }
        int kStep = 4;
        int KalignedStep = (K / kStep) * kStep;
        int k = 0;
        for (k = 0; k < KalignedStep; k += kStep)
        {
            float32x4_t A_data1 = vdupq_n_f32(A[m * K + k]);
            float32x4_t A_data2 = vdupq_n_f32(A[m * K + k + 1]);
            float32x4_t A_data3 = vdupq_n_f32(A[m * K + k + 2]);
            float32x4_t A_data4 = vdupq_n_f32(A[m * K + k + 3]);
            float *C_ptr1 = C + m * N;
            for (int n = 0; n < N; n += 4)
            {
                float32x4_t B_data1 = vld1q_f32(B + (k + 0) * N + n);
                float32x4_t B_data2 = vld1q_f32(B + (k + 1) * N + n);
                float32x4_t B_data3 = vld1q_f32(B + (k + 2) * N + n);
                float32x4_t B_data4 = vld1q_f32(B + (k + 3) * N + n);
                float32x4_t C_data = vld1q_f32(C_ptr1 + n);
                C_data = vmlaq_f32(C_data, A_data1, B_data1);
                C_data = vmlaq_f32(C_data, A_data2, B_data2);
                C_data = vmlaq_f32(C_data, A_data3, B_data3);
                C_data = vmlaq_f32(C_data, A_data4, B_data4);
                vst1q_f32(C_ptr1 + n, C_data);
            }
        }
        for (; k < K; k += 1)
        {
            DATA_TYPE A_data = A[m * K + k];
            for (int n = 0; n < N; n += 1)
            {
                C[m * N + n] += A_data * B[k * N + n];
            }
        }
    }
#endif
}
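For a quick sanity check, a small driver along the following lines can be used. The sizes here are arbitrary (chosen to exercise both the 4-row blocked path and the M/K tails), except that N must be a multiple of 4 as required by the vectorized loops; the result is compared against a naive triple loop.

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

int main(void)
{
    /* Arbitrary test sizes; N must be a multiple of 4 for the NEON path above. */
    int M = 6, N = 8, K = 5;
    float *A = (float *)malloc(M * K * sizeof(float));
    float *B = (float *)malloc(K * N * sizeof(float));
    float *C = (float *)malloc(M * N * sizeof(float));
    float *Ref = (float *)malloc(M * N * sizeof(float));

    for (int i = 0; i < M * K; i++) A[i] = (float)(i % 7) * 0.5f;
    for (int i = 0; i < K * N; i++) B[i] = (float)(i % 5) * 0.25f;

    mat_mul(A, B, C, M, N, K);

    /* Naive reference for comparison. */
    for (int i = 0; i < M; i++)
        for (int j = 0; j < N; j++)
        {
            float acc = 0.0f;
            for (int k = 0; k < K; k++) acc += A[i * K + k] * B[k * N + j];
            Ref[i * N + j] = acc;
        }

    float max_err = 0.0f;
    for (int i = 0; i < M * N; i++)
    {
        float e = fabsf(C[i] - Ref[i]);
        if (e > max_err) max_err = e;
    }
    printf("max abs error vs naive: %g\n", max_err);

    free(A); free(B); free(C); free(Ref);
    return 0;
}

On 32-bit ARM the NEON unit usually has to be enabled explicitly (for example gcc -O2 -mfpu=neon); on AArch64 it is available by default.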