diff --git a/README.md b/README.md
index 9ba6241da6c17..69363e66a8760 100644
--- a/README.md
+++ b/README.md
@@ -246,7 +246,7 @@ cadaver, cauliflower, cabbage (vegetable), catalpa (tree) and Cailleach.
 
 ### Perplexity (Measuring model quality)
 
-You can pass `--perplexity` as a command line option to measure perplexity over the given prompt. For more background,
+You can use the `perplexity` example to measure perplexity over the given prompt. For more background,
 see https://huggingface.co/docs/transformers/perplexity. However, in general, lower perplexity is better for LLMs.
 
 #### Latest measurements
@@ -269,10 +269,10 @@ Perplexity - model options
 #### How to run
 
 1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
-2. Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+2. Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
 3. Output:
 ```
-Calculating perplexity over 655 chunks
+perplexity : calculating perplexity over 655 chunks
 24.43 seconds per pass - ETA 4.45 hours
 [1]4.5970,[2]5.1807,[3]6.0382,...
 ```
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index f617ba365dd05..75d526d3df603 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -19,7 +19,7 @@ std::vector<float> softmax(const std::vector<float>& logits) {
 
 void perplexity(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
-    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
 
     auto tokens = ::llama_tokenize(ctx, params.prompt, true);
diff --git a/ggml.c b/ggml.c
index c9a4e867523b2..c5330d3dcfc11 100644
--- a/ggml.c
+++ b/ggml.c
@@ -697,7 +697,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
 // method 4
 // blocks of QK elements
 // represented with 2 floats (min + delta) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
-void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+void quantize_row_q4_1_reference(const float * restrict x, void * restrict y, int k) {
     assert(k % QK == 0);
     const int nb = k / QK;
 
@@ -745,6 +745,102 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
     }
 }
 
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    assert(k % QK == 0);
+
+#if defined(__AVX2__)
+    const int nb = k / QK;
+    const size_t bs = 2*sizeof(float) + QK/2;
+
+    uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
+    uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float));
+    uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
+
+    uint8_t pp[QK/2];
+
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max for the block
+        __m256 vmax;
+        vmax = _mm256_max_ps( v0, v1 );
+        vmax = _mm256_max_ps( vmax, v2 );
+        vmax = _mm256_max_ps( vmax, v3 );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Compute min for the block
+        __m256 vmin;
+        vmin = _mm256_min_ps( v0, v1 );
+        vmin = _mm256_min_ps( vmin, v2 );
+        vmin = _mm256_min_ps( vmin, v3 );
+
+        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
+        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
+        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
+        const float minScalar = _mm_cvtss_f32( min4 );
+
+        // Quantize these floats
+        const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        *(float *)pm = minScalar;
+        *(float *)pd = d;
+        pm += bs;
+        pd += bs;
+
+        // x = (x-min)*id
+        const __m256 mul = _mm256_set1_ps( id );
+        const __m256 off = _mm256_set1_ps( minScalar );
+        v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
+        v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
+        v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
+        v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+        // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );  // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        // Compress the vector into 4 bit/value, and store
+        __m128i res = packNibbles( i0 );
+        _mm_storeu_si128( ( __m128i* )pb, res );
+
+        pb += bs;
+    }
+#else
+    // scalar
+    quantize_row_q4_1_reference(x, y, k);
+#endif
+}
+
 // TODO: vectorize
 void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
     assert(k % QK == 0);
@@ -10398,7 +10494,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, i
         uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
         uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
 
-        quantize_row_q4_1(src + j, pd, k);
+        quantize_row_q4_1_reference(src + j, pd, k);
 
         for (int i = 0; i < nb; i++) {
            for (int l = 0; l < qk; l += 2) {
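
Notes on what the changes above compute, with two small standalone sketches.

The `perplexity` example reports exp of the average negative log-likelihood of the text under the model: softmax the logits at each position, take the probability assigned to the actual next token, average the negative logs, and exponentiate; lower is better. Below is a minimal C sketch of just that formula, using invented toy logits rather than anything from llama.cpp.

```c
// Toy perplexity computation: ppl = exp( -(1/N) * sum_i log p(token_i) ).
// The logits and targets below are invented for illustration.
#include <math.h>
#include <stdio.h>

// probability of `target` under a numerically stable softmax over `n` logits
static double prob_of_target(const double *logits, int n, int target) {
    double maxl = logits[0];
    for (int i = 1; i < n; i++) if (logits[i] > maxl) maxl = logits[i];
    double sum = 0.0;
    for (int i = 0; i < n; i++) sum += exp(logits[i] - maxl);
    return exp(logits[target] - maxl) / sum;
}

int main(void) {
    // 3 positions, 4-token vocabulary, one "correct" token per position
    const double logits[3][4] = {
        { 2.0, 0.1, -1.0, 0.3 },
        { 0.0, 1.5,  0.2, 0.1 },
        { 0.5, 0.5,  3.0, 0.0 },
    };
    const int target[3] = { 0, 1, 2 };

    double nll = 0.0;
    for (int i = 0; i < 3; i++) {
        nll += -log(prob_of_target(logits[i], 4, target[i]));
    }
    printf("perplexity: %.4f\n", exp(nll / 3.0)); // lower is better
    return 0;
}
```

The Q4_1 format that `quantize_row_q4_1` / `quantize_row_q4_1_reference` produce stores each block of QK floats as two floats (the block minimum and the step `d = (max - min) / 15`) plus QK/2 bytes holding two 4-bit indices each, so a value is reconstructed as `min + d * q`. The scalar sketch below round-trips one block under that scheme; it is illustrative only: `block_q4_1_sketch`, `quantize_block` and `dequantize_block` are invented names, and the nibble ordering here is not necessarily the one ggml's `packNibbles` path uses. The new AVX2 path performs the same per-block math (min, max, scale, round to nearest, pack), just eight floats at a time.

```c
// Scalar sketch of min/delta 4-bit block quantization (not the ggml implementation).
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK 32

typedef struct {
    float   d;         // step: (max - min) / 15
    float   m;         // block minimum
    uint8_t qs[QK/2];  // two 4-bit values per byte (low nibble = even element here)
} block_q4_1_sketch;   // hypothetical layout, for illustration

static void quantize_block(const float *x, block_q4_1_sketch *b) {
    float min = x[0], max = x[0];
    for (int i = 1; i < QK; i++) {
        if (x[i] < min) min = x[i];
        if (x[i] > max) max = x[i];
    }
    const float d  = (max - min) / ((1 << 4) - 1);
    const float id = d ? 1.0f/d : 0.0f;
    b->d = d;
    b->m = min;
    for (int i = 0; i < QK; i += 2) {
        const uint8_t q0 = (uint8_t) roundf((x[i + 0] - min) * id);
        const uint8_t q1 = (uint8_t) roundf((x[i + 1] - min) * id);
        assert(q0 < 16 && q1 < 16);
        b->qs[i/2] = q0 | (q1 << 4); // pack two nibbles into one byte
    }
}

static void dequantize_block(const block_q4_1_sketch *b, float *y) {
    for (int i = 0; i < QK; i += 2) {
        y[i + 0] = b->m + b->d * (b->qs[i/2] & 0x0F);
        y[i + 1] = b->m + b->d * (b->qs[i/2] >> 4);
    }
}

int main(void) {
    float x[QK], y[QK];
    for (int i = 0; i < QK; i++) x[i] = sinf(0.3f*i); // arbitrary test data
    block_q4_1_sketch b;
    quantize_block(x, &b);
    dequantize_block(&b, y);
    for (int i = 0; i < QK; i++) {
        printf("%2d: %8.4f -> %8.4f\n", i, x[i], y[i]);
    }
    return 0;
}
```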