
Commit 6bc4400

ggml : add Q5 WASM SIMD + GGML_FTYPE

1 parent f0d70f1
File tree: 2 files changed, +177 -2 lines


ggml.c (+160 -2)
@@ -330,7 +330,7 @@ static ggml_fp16_t table_exp_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB)
 static float table_f32_f16[1 << 16];
 
-#if defined(__ARM_NEON)
+#if defined(__ARM_NEON) || defined(__wasm_simd128__)
 #define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
 #define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
 #define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
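
An aside on the macros in this hunk (not part of the commit): each Bn level doubles the list of hex constants by pasting either the first or the second argument onto the running prefix, so a B8-style expansion enumerates all 256 byte patterns used to build the table_b2b_* lookup tables referenced further down (the higher levels are outside this hunk). A tiny standalone C demo of the expansion:

#include <stdio.h>
#include <stdint.h>

// Same first two macro levels as in the diff above, reproduced for a standalone demo.
#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s
#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)

// B2(10, 00, 5) pastes every combination of the byte patterns 10/00 onto the
// prefix 5, giving: 0x51010, 0x51000, 0x50010, 0x50000.
static const uint64_t demo[] = { B2(10, 00, 5) };

int main(void) {
    for (size_t i = 0; i < sizeof(demo)/sizeof(demo[0]); ++i) {
        printf("0x%llx\n", (unsigned long long) demo[i]);
    }
    return 0;
}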
@@ -1087,7 +1087,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
             const v128_t v  = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
             const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
             const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
-            const v128_t vc = wasm_i32x4_min_u(vi, wasm_i32x4_splat(15));
+            const v128_t vc = wasm_i32x4_min(vi, wasm_i32x4_splat(15));
 
             y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
             y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
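
A note on the one-line fix above (not part of the commit): the clamp switches to the signed wasm_i32x4_min, presumably to match the intrinsic name available in wasm_simd128.h; since the lanes are expected to be non-negative after the + 8.5f shift, a signed minimum is enough to keep every quant inside the 4-bit range 0..15. A scalar sketch of the per-lane math, with id standing for the reciprocal block scale as in the surrounding code:

#include <stdint.h>

// Scalar sketch (illustrative, not from the commit) of the SIMD lanes above:
// scale by the reciprocal block scale, bias by 8.5, truncate, clamp to 15.
static inline uint8_t q4_0_quantize_lane(float x, float id) {
    const float v  = x*id;       // scaled value, expected in roughly [-8, 7]
    const float vf = v + 8.5f;   // bias: shift by 8 and round to nearest via truncation
    int vi = (int) vf;           // per-lane equivalent of wasm_i32x4_trunc_sat_f32x4
    if (vi > 15) vi = 15;        // the clamp being fixed: keep the quant a nibble
    return (uint8_t) vi;
}

Two such nibbles are then packed into each output byte, as the two extract_lane lines show.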
@@ -3180,6 +3180,72 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = vaddvq_f32(sumv);
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    uint64_t tmp[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_0 * restrict y0 = &y[i];
+
+        const v128_t m4b  = wasm_i8x16_splat(0x0F);
+        const v128_t s16b = wasm_i8x16_splat(0x10);
+
+        // extract the 5th bit
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_u[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // interleave
+        const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h,  0, 16,  1, 17,  2, 18,  3, 19,  4, 20,  5, 21,  6, 22,  7, 23);
+        const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h,  8, 24,  9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
+
+        // add high bit and sub 16
+        const v128_t v0lf = wasm_i8x16_sub(wasm_v128_or(v0lz, qhl), s16b);
+        const v128_t v0hf = wasm_i8x16_sub(wasm_v128_or(v0hz, qhh), s16b);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+                        wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
+    }
+
+    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
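
How the new Q5_0 path works (explanatory note and scalar sketch, not part of the commit): the low 4 bits of each quant live in qs and the 32 fifth bits are packed into qh. Each byte of qh is expanded through table_b2b_u into eight bytes so that the fifth bits can be OR'd into the unpacked nibbles as 0x10, and subtracting s16b = 0x10 then recovers the signed range [-16, 15] before the widening dot products. A scalar equivalent of one lookup, assuming bit j of the input byte maps to byte j of the 64-bit result:

#include <stdint.h>

// Hedged scalar equivalent of one table_b2b_u[] lookup used above; the 0x10
// per-bit value is inferred from the OR with the nibbles and the s16b subtraction.
static inline uint64_t b2b_expand_u(uint8_t bits) {
    uint64_t r = 0;
    for (int j = 0; j < 8; ++j) {
        if (bits & (1u << j)) {
            r |= (uint64_t) 0x10 << (8*j); // fifth bit of quant j, already shifted to bit 4
        }
    }
    return r;
}

Two consecutive 64-bit expansions are loaded as one v128 (tmp + 0 for the first sixteen quants, tmp + 2 for the rest), OR'd with the interleaved nibbles, widened to i16, multiplied against the q8_0 quants with wasm_i32x4_dot_i16x8, and the per-block result is scaled by x0d*y0->d.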
@@ -3311,6 +3377,77 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
     }
 
     *s = vaddvq_f32(sumv) + summs;
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    float summs = 0.0f;
+
+    uint64_t tmp[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+
+        const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_u[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        static bool x = true;
+
+        // interleave
+        const v128_t v0lz = wasm_v8x16_shuffle(v0l, v0h,  0, 16,  1, 17,  2, 18,  3, 19,  4, 20,  5, 21,  6, 22,  7, 23);
+        const v128_t v0hz = wasm_v8x16_shuffle(v0l, v0h,  8, 24,  9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31);
+
+        // add high bit
+        const v128_t v0lf = wasm_v128_or(v0lz, qhl);
+        const v128_t v0hf = wasm_v128_or(v0hz, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+                        wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
+    }
+
+    *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+         wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
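
The Q5_1 path above differs from Q5_0 in two ways: the fifth bit is OR'd in without the 16 subtraction (Q5_1 quants stay in 0..31 and the block carries a separate minimum m), and summs folds that minimum into the result through the precomputed s0 + s1 sums of the q8_1 block. As an illustration only (not part of the commit), here is a scalar reference for one block that mirrors the SIMD path; the quant ordering is inferred from the interleave shuffle, and it assumes y->s0 + y->s1 equals the q8_1 scale times the sum of its 32 quants:

#include <string.h>
#include <stdint.h>

// Hedged scalar reference for one Q5_1 x Q8_1 block; meant to sit inside ggml.c,
// reusing its block_q5_1/block_q8_1 types and the GGML_FP16_TO_FP32 macro.
static float q5_1_q8_1_block_ref(const block_q5_1 * restrict x, const block_q8_1 * restrict y) {
    const float d = GGML_FP16_TO_FP32(x->d);
    const float m = GGML_FP16_TO_FP32(x->m);

    uint32_t qh;
    memcpy(&qh, x->qh, sizeof(qh));

    int sumi = 0;
    for (int j = 0; j < 16; ++j) {
        // assumed layout: byte j holds quants 2j (low nibble) and 2j+1 (high nibble),
        // and bit k of qh is the fifth bit of quant k
        const int v0 = (x->qs[j] & 0x0F) | (((qh >> (2*j + 0)) & 1) << 4);
        const int v1 = (x->qs[j] >>   4) | (((qh >> (2*j + 1)) & 1) << 4);
        sumi += v0*y->qs[2*j + 0] + v1*y->qs[2*j + 1];
    }

    // (d*q + m) . y  ==  d*y_d*(q . q8) + m*sum(y), with sum(y) assumed equal to s0 + s1
    return d*y->d*sumi + m*(y->s0 + y->s1);
}

(The stray static bool x = true; in the hunk is an unused local and has no effect on the computation.)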
@@ -4057,6 +4194,27 @@ bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
 
+enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
+    enum ggml_type wtype = GGML_TYPE_COUNT;
+
+    switch (ftype) {
+        case GGML_FTYPE_ALL_F32:              wtype = GGML_TYPE_F32;   break;
+        case GGML_FTYPE_MOSTLY_F16:           wtype = GGML_TYPE_F16;   break;
+        case GGML_FTYPE_MOSTLY_Q4_0:          wtype = GGML_TYPE_Q4_0;  break;
+        case GGML_FTYPE_MOSTLY_Q4_1:          wtype = GGML_TYPE_Q4_1;  break;
+        case GGML_FTYPE_MOSTLY_Q4_2:          wtype = GGML_TYPE_Q4_2;  break;
+        case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;
+        case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
+        case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
+        case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
+        case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
+    }
+
+    GGML_ASSERT(wtype != GGML_TYPE_COUNT);
+
+    return wtype;
+}
+
 static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
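
Per the TODO added in ggml.h below, this helper is aimed at the example model loaders. A hypothetical caller (the function name, stream, and header layout are illustrative, not from this commit):

#include <stdio.h>
#include <stdint.h>
#include "ggml.h"

// Hypothetical loader snippet: read the ftype stored in a model header and map it
// to the tensor type that the 2d weight tensors should be created with.
static enum ggml_type read_wtype(FILE * fin) {
    int32_t ftype = 0;
    if (fread(&ftype, sizeof(ftype), 1, fin) != 1) {
        return GGML_TYPE_COUNT; // caller treats this as a read error
    }
    return ggml_ftype_to_ggml_type((enum ggml_ftype) ftype); // asserts on unknown values
}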

ggml.h (+17)
@@ -232,6 +232,20 @@ extern "C" {
         GGML_TYPE_COUNT,
     };
 
+    // model file types
+    enum ggml_ftype {
+        GGML_FTYPE_UNKNOWN     = -1,
+        GGML_FTYPE_ALL_F32     = 0,
+        GGML_FTYPE_MOSTLY_F16  = 1,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0 = 2,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_1 = 3,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
+        GGML_FTYPE_MOSTLY_Q4_2 = 5,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q8_0 = 7,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_0 = 8,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_1 = 9,  // except 1d tensors
+    };
+
     // available tensor operations:
     enum ggml_op {
         GGML_OP_NONE = 0,
@@ -385,6 +399,9 @@ extern "C" {
 
     GGML_API bool ggml_is_quantized(enum ggml_type type);
 
+    // TODO: temporary until model loading of ggml examples is refactored
+    GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
+
     // main
 
     GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
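
The "except 1d tensors" comments above describe the convention the file types encode: 1d tensors (such as biases and norm weights) are left unquantized. A hedged sketch of how an example quantizer might apply that convention with the new helper (the function name and signature are illustrative, and the sketch keeps 1d tensors in F32):

#include "ggml.h"

// Illustrative helper (not part of this commit): choose the storage type for a
// tensor given the target file type, keeping 1d tensors unquantized.
static enum ggml_type tensor_type_for(enum ggml_ftype ftype, int n_dims) {
    if (ftype == GGML_FTYPE_ALL_F32 || n_dims == 1) {
        return GGML_TYPE_F32;
    }
    return ggml_ftype_to_ggml_type(ftype);
}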
