47
47
GGML_METAL_DECL_KERNEL (relu);
48
48
GGML_METAL_DECL_KERNEL (soft_max);
49
49
GGML_METAL_DECL_KERNEL (diag_mask_inf);
50
+ GGML_METAL_DECL_KERNEL (get_rows_f16);
50
51
GGML_METAL_DECL_KERNEL (get_rows_q4_0);
51
52
GGML_METAL_DECL_KERNEL (rms_norm);
52
- GGML_METAL_DECL_KERNEL (mul_mat_q4_0_f32);
53
53
GGML_METAL_DECL_KERNEL (mul_mat_f16_f32);
54
+ GGML_METAL_DECL_KERNEL (mul_mat_q4_0_f32);
54
55
GGML_METAL_DECL_KERNEL (rope);
55
56
GGML_METAL_DECL_KERNEL (cpy_f32_f16);
56
57
GGML_METAL_DECL_KERNEL (cpy_f32_f32);
130
131
GGML_METAL_ADD_KERNEL (relu);
131
132
GGML_METAL_ADD_KERNEL (soft_max);
132
133
GGML_METAL_ADD_KERNEL (diag_mask_inf);
134
+ GGML_METAL_ADD_KERNEL (get_rows_f16);
133
135
GGML_METAL_ADD_KERNEL (get_rows_q4_0);
134
136
GGML_METAL_ADD_KERNEL (rms_norm);
135
- GGML_METAL_ADD_KERNEL (mul_mat_q4_0_f32);
136
137
GGML_METAL_ADD_KERNEL (mul_mat_f16_f32);
138
+ GGML_METAL_ADD_KERNEL (mul_mat_q4_0_f32);
137
139
GGML_METAL_ADD_KERNEL (rope);
138
140
GGML_METAL_ADD_KERNEL (cpy_f32_f16);
139
141
GGML_METAL_ADD_KERNEL (cpy_f32_f32);
@@ -498,6 +500,14 @@ void ggml_metal_graph_compute(
498
500
499
501
// use custom matrix x vector kernel
500
502
switch (src0t) {
503
+ case GGML_TYPE_F16:
504
+ {
505
+ GGML_ASSERT (ne02 == ne12);
506
+
507
+ nth0 = 64 ;
508
+ nth1 = 1 ;
509
+ [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
510
+ } break ;
501
511
case GGML_TYPE_Q4_0:
502
512
{
503
513
GGML_ASSERT (ne02 == 1 );
@@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
507
517
nth1 = 4 ;
508
518
[encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_0_f32];
509
519
} break ;
510
- case GGML_TYPE_F16:
511
- {
512
- GGML_ASSERT (ne02 == ne12);
513
-
514
- nth0 = 32 ;
515
- nth1 = 1 ;
516
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
517
- } break ;
518
520
default : GGML_ASSERT (false && " not implemented" );
519
521
};
520
522
@@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
551
553
}
552
554
553
555
switch (src0->type ) {
556
+ case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_get_rows_f16]; break ;
554
557
case GGML_TYPE_Q4_0: [encoder setComputePipelineState: ctx->pipeline_get_rows_q4_0]; break ;
555
558
default : GGML_ASSERT (false && " not implemented" );
556
559
}
0 commit comments