@@ -688,7 +688,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 #endif
 }
 
-static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
+static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK == 0);
     const int nb = k / QK;
@@ -729,6 +729,93 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
     }
 }
 
+static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK == 0);
+
+#if defined(__AVX2__)
+    const int nb = k / QK;
+
+    block_q4_1 * restrict y = vy;
+
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max for the block
+        __m256 vmax;
+        vmax = _mm256_max_ps( v0, v1 );
+        vmax = _mm256_max_ps( vmax, v2 );
+        vmax = _mm256_max_ps( vmax, v3 );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Compute min for the block
+        __m256 vmin;
+        vmin = _mm256_min_ps( v0, v1 );
+        vmin = _mm256_min_ps( vmin, v2 );
+        vmin = _mm256_min_ps( vmin, v3 );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
+        // Quantize these floats
+        const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].m = minScalar;
+        y[i].d = d;
+
+        // x = (x-min)*id
+        const __m256 mul = _mm256_set1_ps( id );
+        const __m256 off = _mm256_set1_ps( minScalar );
+        v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
+        v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
+        v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
+        v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        // Compress the vector into 4 bits/value, and store
+        __m128i res = packNibbles( i0 );
+        _mm_storeu_si128( ( __m128i* )y[i].qs, res );
+    }
+#else
+    // scalar
+    quantize_row_q4_1_reference(x, vy, k);
+#endif
+}
+
 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
     assert(k % QK == 0);
     const int nb = k / QK;
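Two details in the AVX2 hunk are worth spelling out. First, the horizontal min/max reduction: after the pairwise _mm256_max_ps calls, vmax holds eight lane-wise maxima. _mm256_extractf128_ps and _mm256_castps256_ps128 fold the upper and lower 128-bit halves together, _mm_movehl_ps folds elements 2 and 3 onto 0 and 1, and _mm_movehdup_ps duplicates element 1 onto element 0, so after the final _mm_max_ss the block maximum sits in lane 0 ready for _mm_cvtss_f32. Second, _mm256_packs_epi32 and _mm256_packs_epi16 interleave within each 128-bit lane rather than across the full register, which is why the bytes come out in the shuffled order noted in the comments and need the _mm256_permutevar8x32_epi32 with pattern 0, 4, 1, 5, 2, 6, 3, 7 to restore element order. (As an aside, _MM_ROUND_NEAREST is an MXCSR constant with value 0, which coincides with _MM_FROUND_TO_NEAREST_INT, so the _mm256_round_ps calls do round to nearest as intended.) Builds without AVX2 simply fall through to the scalar reference in the #else branch.

packNibbles is defined elsewhere in ggml.c and is not part of this hunk. A plausible sketch, under the assumption that it takes 32 bytes already in 0..15 and packs adjacent pairs low-nibble-first into 16 bytes:

#include <immintrin.h>

// Hypothetical sketch of packNibbles (the real helper lives elsewhere in
// ggml.c). Assumption: each 16-bit lane of `bytes` holds two values in
// 0..15, and the result packs them as (low | high << 4) per output byte.
static inline __m128i packNibbles_sketch( __m256i bytes ) {
    // Within each 16-bit lane, shift the high byte's nibble down next to
    // the low byte's nibble and OR them together.
    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
    const __m256i high    = _mm256_srli_epi16( _mm256_andnot_si256( lowByte, bytes ), 4 );
    const __m256i low     = _mm256_and_si256( lowByte, bytes );
    const __m256i packed  = _mm256_or_si256( low, high );

    // Narrow the 16 uint16 lanes (each now a value in 0..255) to 16 bytes.
    return _mm_packus_epi16( _mm256_castsi256_si128( packed ),
                             _mm256_extracti128_si256( packed, 1 ) );
}

After the permute, byte 2l of i0 holds element 2l and byte 2l+1 holds element 2l+1, so this pairwise packing yields the same qs layout the scalar path produces.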
@@ -10135,7 +10222,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
     for (int j = 0; j < n; j += k) {
         block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
 
-        quantize_row_q4_1(src + j, y, k);
+        quantize_row_q4_1_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK; l += 2) {
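The call-site change in ggml_quantize_q4_1 is the flip side of the rename: file-level quantization keeps using the scalar reference rather than the new SIMD dispatcher, presumably so that written model files, and the value histogram collected in the loop above, stay bit-identical across builds regardless of which SIMD path the host supports. A hypothetical call, assuming src, dst, n, and k are set up by the caller and the histogram has one bin per 4-bit value:

// n floats in, quantized blocks out; k is the row length, a multiple of QK.
int64_t hist[16] = {0};  // assumed: 16 bins, one per nibble value
const size_t written = ggml_quantize_q4_1(src, dst, n, k, hist);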