Skip to content

Commit ad7007a

Browse files
slaren authored and ggerganov committed
ggml : AVX2 implementation of ggml_vec_dot_q4_1_q8_0 (#1051)
1 parent 4262305 commit ad7007a

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

ggml.c

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2462,6 +2462,62 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
24622462
}
24632463

24642464
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2465+
#elif defined(__AVX2__)
2466+
// Initialize accumulator with zeros
2467+
__m256 acc = _mm256_setzero_ps();
2468+
2469+
// Main loop
2470+
for (int i = 0; i < nb; ++i) {
2471+
const float * d0 = &x[i].d;
2472+
const float * d1 = &y[i].d;
2473+
const float * m0 = &x[i].m;
2474+
2475+
const __m256 d0v = _mm256_broadcast_ss( d0 );
2476+
const __m256 d1v = _mm256_broadcast_ss( d1 );
2477+
const __m256 m0v = _mm256_broadcast_ss( m0 );
2478+
2479+
// Compute combined scales
2480+
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2481+
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2482+
2483+
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2484+
const __m256i bx = bytesFromNibbles( x[i].qs );
2485+
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2486+
2487+
// Get absolute values of x vectors
2488+
const __m256i ax = _mm256_sign_epi8( bx, bx );
2489+
2490+
// Sign the values of the y vectors
2491+
const __m256i sy = _mm256_sign_epi8( by, bx );
2492+
2493+
// Perform multiplication and create 16-bit values
2494+
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2495+
const __m256i ones = _mm256_set1_epi16( 1 );
2496+
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2497+
2498+
// Convert vector of 8 int32_t to 8 floats
2499+
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
2500+
2501+
// Accumulate d0*d1*x*y
2502+
acc = _mm256_fmadd_ps( d0d1, xy, acc );
2503+
2504+
// Compute sum of y values
2505+
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2506+
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2507+
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2508+
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2509+
2510+
// Accumulate d1*m0*y
2511+
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2512+
}
2513+
2514+
// Return horizontal sum of the acc vector
2515+
__m128 res = _mm256_extractf128_ps( acc, 1 );
2516+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2517+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2518+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2519+
2520+
sumf = _mm_cvtss_f32( res );
24652521
#else
24662522
// scalar
24672523
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)