Skip to content

Commit 694af8a

Browse files
committed
metal : still optimizing Q4_K
This commit pushes it down to 25.3 ms / token. The crazy idea of using 6 bits for the scales is really costly on Metal: if I remove the bit fiddling necessary to unpack the block scales, the time drops almost to Q4_0's 23 ms/token. Before pushing the k-quants upstream I had a Q4_K variant that used 8-bit scales. It wasn't more accurate, used 0.125 bits more per weight, ran slightly slower on the CPU (due to the larger model size and being memory bound there), and the difference was entirely negligible under CUDA. So, I decided to publish the version with 6-bit scales. Perhaps I should reconsider and change to 8-bit scales?
1 parent 95ec7f0 commit 694af8a

File tree

1 file changed

+54
-14
lines changed

1 file changed

+54
-14
lines changed

ggml-metal.metal

+54-14
Original file line numberDiff line numberDiff line change
@@ -986,42 +986,82 @@ kernel void kernel_mul_mat_q4_k_f32(
986986

987987
const int tid = tpitg.y; // 0...16
988988
const int il = tid/4; // 0...3
989-
const int ir = tid%4; // 0...3
989+
//const int ir = tid%4; // 0...3
990+
const int ir = tid - 4*il;// 0...3
990991
const int n = 4;
991992

992993
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
993994
const int in = il%2;
995+
994996
const int l0 = n*(2*ir + in);
997+
const int q_offset = 32*im + l0;
998+
const int y_offset = 64*im + l0;
995999

9961000
sum[ith] = 0.0f;
9971001

1002+
//uint16_t aux_scales[4];
1003+
//thread uint8_t * sc = (thread uint8_t *)aux_scales;
1004+
1005+
//uint32_t aux32[4];
1006+
//thread const uint8_t * sc = (thread const uint8_t *)aux32;
1007+
1008+
uchar2 sc1, sc2, sc3, sc4;
1009+
9981010
float sumf = 0;
9991011
for (int i = tpitg.x; i < nb; i += tptg.x) {
10001012

1001-
device const uint8_t * q1 = (x + i)->qs + 32*im + l0;
1002-
device const float * y1 = yy + i*QK_K + 64*im + l0;
1013+
device const uint8_t * q1 = (x + i)->qs + q_offset;
10031014
device const uint8_t * q2 = q1 + 64;
1015+
device const float * y1 = yy + i*QK_K + y_offset;
10041016
device const float * y2 = y1 + 128;
10051017

1006-
device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1007-
10081018
const float dall = (float)((x + i)->d);
10091019
const float dmin = (float)((x + i)->dmin);
10101020

1011-
const uchar2 sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1012-
const uchar2 sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1013-
const uchar2 sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1014-
const uchar2 sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1021+
//device const uint32_t * a = (device const uint32_t *)(x + i)->scales;
1022+
//aux32[0] = a[0] & 0x3f3f3f3f; // scales for 0, 32, 64, 96
1023+
//aux32[1] = a[1] & 0x3f3f3f3f; // mins for 0, 32, 64, 96
1024+
//aux32[2] = ((a[2] >> 0) & 0x0f0f0f0f) | ((a[0] & 0xc0c0c0c0) >> 2); // scales for 128, 160, 192, 224
1025+
//aux32[3] = ((a[2] >> 4) & 0x0f0f0f0f) | ((a[1] & 0xc0c0c0c0) >> 2); // mins for 128, 160, 192, 224
1026+
1027+
//aux_scales[0] = (uint16_t)(a[im+0] & kmask1);
1028+
//aux_scales[1] = (uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2));
1029+
//aux_scales[2] = (uint16_t)(a[im+2] & kmask1);
1030+
//aux_scales[3] = (uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2));
1031+
1032+
device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1033+
sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1034+
sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1035+
sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1036+
sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
10151037

1016-
float2 s = {0.f, 0.f};
1038+
//float2 s = {0.f, 0.f};
1039+
float4 s = {0.f, 0.f, 0.f, 0.f};
1040+
float smin = 0;
10171041
for (int l = 0; l < n; ++l) {
1018-
s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4)
1019-
+ y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4);
1020-
s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1042+
1043+
////s[0] += y1[l] * sc[0] * (q1[l] & 0xF) + y1[l+32] * sc[1] * (q1[l] >> 4)
1044+
//// + y2[l] * sc[2] * (q2[l] & 0xF) + y2[l+32] * sc[3] * (q2[l] >> 4);
1045+
////s[1] += y1[l] * sc[4] + y1[l+32] * sc[5] + y2[l] * sc[6] + y2[l+32] * sc[7];
1046+
1047+
////s[0] += y1[l] * sc[2*im+0] * (q1[l] & 0xF) + y1[l+32] * sc[2*im+1] * (q1[l] >> 4)
1048+
//// + y2[l] * sc[2*im+8] * (q2[l] & 0xF) + y2[l+32] * sc[2*im+9] * (q2[l] >> 4);
1049+
////s[1] += y1[l] * sc[2*im+4] + y1[l+32] * sc[2*im+5] + y2[l] * sc[2*im+12] + y2[l+32] * sc[2*im+13];
1050+
1051+
//s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4)
1052+
// + y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4);
1053+
//s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1054+
1055+
s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
1056+
s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
1057+
smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1058+
10211059
}
1022-
sumf += dall * s[0] - dmin * s[1];
1060+
//sumf += dall * s[0] - dmin * s[1];
1061+
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
10231062

10241063
}
1064+
10251065
sum[ith] = sumf;
10261066

10271067
//

0 commit comments

Comments
 (0)