Skip to content

Commit 44f906e

Browse files
committed
metal : add f16 support
1 parent d5b111f commit 44f906e

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

ggml-metal.m

+13-10
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@
4747
GGML_METAL_DECL_KERNEL(relu);
4848
GGML_METAL_DECL_KERNEL(soft_max);
4949
GGML_METAL_DECL_KERNEL(diag_mask_inf);
50+
GGML_METAL_DECL_KERNEL(get_rows_f16);
5051
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
5152
GGML_METAL_DECL_KERNEL(rms_norm);
52-
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
5353
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
54+
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
5455
GGML_METAL_DECL_KERNEL(rope);
5556
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
5657
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -130,10 +131,11 @@
130131
GGML_METAL_ADD_KERNEL(relu);
131132
GGML_METAL_ADD_KERNEL(soft_max);
132133
GGML_METAL_ADD_KERNEL(diag_mask_inf);
134+
GGML_METAL_ADD_KERNEL(get_rows_f16);
133135
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
134136
GGML_METAL_ADD_KERNEL(rms_norm);
135-
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
136137
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
138+
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
137139
GGML_METAL_ADD_KERNEL(rope);
138140
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
139141
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -498,6 +500,14 @@ void ggml_metal_graph_compute(
498500

499501
// use custom matrix x vector kernel
500502
switch (src0t) {
503+
case GGML_TYPE_F16:
504+
{
505+
GGML_ASSERT(ne02 == ne12);
506+
507+
nth0 = 64;
508+
nth1 = 1;
509+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
510+
} break;
501511
case GGML_TYPE_Q4_0:
502512
{
503513
GGML_ASSERT(ne02 == 1);
@@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
507517
nth1 = 4;
508518
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
509519
} break;
510-
case GGML_TYPE_F16:
511-
{
512-
GGML_ASSERT(ne02 == ne12);
513-
514-
nth0 = 32;
515-
nth1 = 1;
516-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
517-
} break;
518520
default: GGML_ASSERT(false && "not implemented");
519521
};
520522

@@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
551553
}
552554

553555
switch (src0->type) {
556+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
554557
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
555558
default: GGML_ASSERT(false && "not implemented");
556559
}

ggml-metal.metal

+16
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,22 @@ kernel void kernel_diag_mask_inf(
169169
}
170170
}
171171

172+
// Gathers rows from an f16 matrix and writes them out as f32.
//
// One thread handles one output row: thread i reads the row index r = src1[i],
// then converts the ne00 half values of src0 row r into dst row i.
//
//   src0 - source rows, read as half; row r starts r*nb01 bytes from the base
//   src1 - row indices, one per thread
//   dst  - destination rows, written as float
//   ne00 - number of elements copied per row
//   nb01 - distance between consecutive src0 rows, applied as a byte offset
//   nb1  - distance between consecutive dst rows
//          NOTE(review): treated here as a byte stride, like nb01 - confirm
//          against the host code that binds this argument
//   tpig - thread position in the grid, used as the output row number
//
// No bounds checking is performed here on tpig or on the values in src1.
kernel void kernel_get_rows_f16(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb1,
        uint tpig[[thread_position_in_grid]]) {
    const int i = tpig;
    const int r = src1[i];

    // Apply both strides on char pointers so they are byte offsets. Indexing
    // the float pointer with i*nb1 would scale the stride by sizeof(float)
    // and land past the intended row for every i > 0.
    device const half * src_row = (device const half *) ((device const char *) src0 + r*nb01);
    device      float * dst_row = (device      float *) ((device       char *) dst  + i*nb1);

    for (int j = 0; j < ne00; j++) {
        dst_row[j] = src_row[j];
    }
}
187+
172188
kernel void kernel_get_rows_q4_0(
173189
device const void * src0,
174190
device const int * src1,

llama.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,6 @@ static void llama_model_load_internal(
961961
model.hparams = ml->file_loaders.at(0)->hparams;
962962
llama_file_version file_version = ml->file_loaders.at(0)->file_version;
963963
auto & hparams = model.hparams;
964-
uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
965964

966965
{
967966
switch (hparams.n_layer) {
@@ -975,6 +974,8 @@ static void llama_model_load_internal(
975974
hparams.n_ctx = n_ctx;
976975
}
977976

977+
const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
978+
978979
{
979980
fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
980981
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);

0 commit comments

Comments
 (0)