@@ -2462,6 +2462,62 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2462
2462
}
2463
2463
2464
2464
sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2465
+ #elif defined(__AVX2__ )
2466
+ // Initialize accumulator with zeros
2467
+ __m256 acc = _mm256_setzero_ps ();
2468
+
2469
+ // Main loop
2470
+ for (int i = 0 ; i < nb ; ++ i ) {
2471
+ const float * d0 = & x [i ].d ;
2472
+ const float * d1 = & y [i ].d ;
2473
+ const float * m0 = & x [i ].m ;
2474
+
2475
+ const __m256 d0v = _mm256_broadcast_ss ( d0 );
2476
+ const __m256 d1v = _mm256_broadcast_ss ( d1 );
2477
+ const __m256 m0v = _mm256_broadcast_ss ( m0 );
2478
+
2479
+ // Compute combined scales
2480
+ const __m256 d0d1 = _mm256_mul_ps ( d0v , d1v );
2481
+ const __m256 d1m0 = _mm256_mul_ps ( d1v , m0v );
2482
+
2483
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2484
+ const __m256i bx = bytesFromNibbles ( x [i ].qs );
2485
+ const __m256i by = _mm256_loadu_si256 ( (const __m256i * )y [i ].qs );
2486
+
2487
+ // Get absolute values of x vectors
2488
+ const __m256i ax = _mm256_sign_epi8 ( bx , bx );
2489
+
2490
+ // Sign the values of the y vectors
2491
+ const __m256i sy = _mm256_sign_epi8 ( by , bx );
2492
+
2493
+ // Perform multiplication and create 16-bit values
2494
+ const __m256i dot = _mm256_maddubs_epi16 ( ax , sy );
2495
+ const __m256i ones = _mm256_set1_epi16 ( 1 );
2496
+ const __m256i xy_q = _mm256_madd_epi16 ( ones , dot );
2497
+
2498
+ // Convert to vector of 8 int32_t to 8 floats
2499
+ const __m256 xy = _mm256_cvtepi32_ps ( xy_q );
2500
+
2501
+ // Accumulate d0*d1*x*y
2502
+ acc = _mm256_fmadd_ps ( d0d1 , xy , acc );
2503
+
2504
+ // Compute sum of y values
2505
+ const __m256i y16_l = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
2506
+ const __m256i y16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
2507
+ const __m256i ysumi = _mm256_madd_epi16 ( _mm256_add_epi16 (y16_l , y16_h ), ones );
2508
+ const __m256 ysum = _mm256_cvtepi32_ps ( ysumi );
2509
+
2510
+ // Accumulate d1*m0*y
2511
+ acc = _mm256_fmadd_ps ( d1m0 , ysum , acc );
2512
+ }
2513
+
2514
+ // Return horizontal sum of the acc vector
2515
+ __m128 res = _mm256_extractf128_ps ( acc , 1 );
2516
+ res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
2517
+ res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
2518
+ res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
2519
+
2520
+ sumf = _mm_cvtss_f32 ( res );
2465
2521
#else
2466
2522
// scalar
2467
2523
for (int i = 0 ; i < nb ; i ++ ) {
0 commit comments