Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
245f391
graph : reuse hybrid graphs
ggerganov Oct 9, 2025
638e2c2
graph : reuse recurrent graphs
ggerganov Oct 9, 2025
0b9c1ae
metal : fix mul-mm condition + fix mul-mv permuted kernels
ggerganov Oct 9, 2025
1f02d93
graph : fix reuse check for recurrent inputs
ggerganov Oct 10, 2025
00f115f
memory : move the recurrent state into the memory context
ggerganov Oct 10, 2025
2744d61
Revert "memory : move the recurrent state into the memory context"
ggerganov Oct 10, 2025
ab3f3fe
Merge branch 'gg/metal-mul-mat-fixes' into gg/graph-mamba-reuse
gabe-l-hart Oct 10, 2025
8c23c43
Added: tri, cumsum. Still a mess.
gabe-l-hart Oct 10, 2025
2a2e79c
feat(tests): Add --verbose | -v flag to test-backend-ops to print ten…
gabe-l-hart Oct 10, 2025
092f740
test: Add cumsum tests to test-backend-ops
gabe-l-hart Oct 10, 2025
6949ce7
feat(ggml-cpu): Add cumsum support for f16 and bf16
gabe-l-hart Oct 10, 2025
f8fba60
feat(ggml-cpu): Add F16 and BF16 support for tri
gabe-l-hart Oct 13, 2025
058160a
test: Add test cases for tri
gabe-l-hart Oct 13, 2025
86ce3da
chore: TODOs to loosen assertions in tri for ggml_is_contiguous
gabe-l-hart Oct 13, 2025
3a8958f
feat(ggml-metal): Initial (slow) implementation of cumsum for metal
gabe-l-hart Oct 13, 2025
cbaed86
feat(ggml-metal): Add stubs for metal tri
gabe-l-hart Oct 13, 2025
e596469
test: Use looser nmse for lower-precision types for cumsum
gabe-l-hart Oct 13, 2025
3011a6e
Merge remote-tracking branch 'origin/master' into Mamba2SSD
gabe-l-hart Oct 13, 2025
112d339
test: Allow multiple verbose flags to fully print tensors
gabe-l-hart Oct 15, 2025
78e137f
feat(llama-gguf): Print out the tensor type in llama-gguf r
gabe-l-hart Sep 26, 2025
e5587cb
feat(ggml-metal): Efficient implementation of cumsum for metal
gabe-l-hart Oct 15, 2025
0468b99
test: More verbose printing and better cumsum tests
gabe-l-hart Oct 15, 2025
c71e35e
fix(ggml-metal): better granularity for support bool for CUMSUM and TRI
gabe-l-hart Oct 15, 2025
5f0d2a1
feat(ggml-metal): Metal impl of tri
gabe-l-hart Oct 15, 2025
426580d
Merge remote-tracking branch 'origin/master' into Mamba2SSD
gabe-l-hart Oct 15, 2025
ba3b8db
fix(ggml-cpu): Fix warnings from build with gcc
gabe-l-hart Oct 15, 2025
dfae909
feat(ggml-cuda): common implementation of prefix sum
gabe-l-hart Oct 16, 2025
d1f8658
feat(ggml-cuda): CUDA implementation of CUMSUM
gabe-l-hart Oct 16, 2025
5071fbd
feat(ggml-cuda): CUDA implementation of TRI
gabe-l-hart Oct 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
feat(ggml-cuda): CUDA implementation of TRI
Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <[email protected]>
  • Loading branch information
gabe-l-hart committed Oct 16, 2025
commit 5071fbd5786558ab21127489069895a58c4bb838
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/topk-moe.cuh"
#include "ggml-cuda/tri.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv.cuh"
Expand Down Expand Up @@ -2516,6 +2517,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
Expand Down Expand Up @@ -3655,6 +3659,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true;
Expand Down
109 changes: 109 additions & 0 deletions ggml/src/ggml-cuda/tri.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include "tri.cuh"
#include "ggml.h"
#include <cmath>

// Triangle membership test: decides whether index `i` is kept relative to
// index `r` for the given triangle type. The TRI kernel passes the element
// index within a row as `i` and the row index as `r`.
//
//   GGML_TRI_TYPE_LOWER      : i <  r
//   GGML_TRI_TYPE_LOWER_DIAG : i <= r
//   GGML_TRI_TYPE_UPPER      : i >  r
//   GGML_TRI_TYPE_UPPER_DIAG : i >= r
//
// Any other type keeps nothing (returns false).
//
// The indices are 64-bit: the kernel computes them as int64_t, and taking
// them as `int` would silently narrow them before the comparison.
__device__ static inline bool tri_compare(const int64_t i, const int64_t r, const ggml_tri_type type) {
    switch (type) {
        case GGML_TRI_TYPE_LOWER:      return i <  r;
        case GGML_TRI_TYPE_LOWER_DIAG: return i <= r;
        case GGML_TRI_TYPE_UPPER:      return i >  r;
        case GGML_TRI_TYPE_UPPER_DIAG: return i >= r;
        default:                       return false;
    }
}

// Triangular select/fill kernel.
//
// For every element (i0, i1, i2, i3) of the ne00 x ne01 x ne02 x ne03 tensor:
//   - if tri_compare(i0, i1, ttype) holds, the output is `c`, or the
//     corresponding source element when `c` is NaN;
//   - otherwise the output is 0.
//
// Layout: blocks walk rows. blockIdx.x/y/z give the starting i1/i2/i3 and the
// loops advance by gridDim, so a grid smaller than (ne01, ne02, ne03) still
// covers every row. The threads of a block walk the elements of one row in
// steps of blockDim.x. No shared memory is used.
//
// All nb* arguments are byte strides: nb00..nb03 for `src`, nb0..nb3 for
// `dst`. The per-element strides nb00/nb0 are honored, so the element stride
// does not have to equal sizeof(T).
template<typename T>
static __global__ void tri_kernel(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
        const float c, const ggml_tri_type ttype) {

    // A NaN constant is the signal to pass the source value through.
    const bool keep_org_val = isnan(c);

    for (int64_t i3 = blockIdx.z; i3 < ne03; i3 += gridDim.z) {
        for (int64_t i2 = blockIdx.y; i2 < ne02; i2 += gridDim.y) {
            for (int64_t i1 = blockIdx.x; i1 < ne01; i1 += gridDim.x) {
                // Byte pointers to the start of the current row.
                const char * src_row = (const char *) src + i1*nb01 + i2*nb02 + i3*nb03;
                char       * dst_row = (      char *) dst + i1*nb1  + i2*nb2  + i3*nb3;

                // Each thread processes elements at stride blockDim.x
                for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
                    T * out = (T *) (dst_row + i0*nb0);

                    if (!tri_compare(i0, i1, ttype)) {
                        *out = static_cast<T>(0.f);
                    } else if (keep_org_val) {
                        *out = *(const T *) (src_row + i0*nb00);
                    } else {
                        *out = static_cast<T>(c);
                    }
                }
            }
        }
    }
}

// Host-side launcher for tri_kernel on `stream`.
//
// Uses CUDA_TRI_BLOCK_SIZE threads per block and one block per row, with the
// grid clamped to the CUDA grid-size limits; the kernel's row loops pick up
// whatever the clamped grid does not cover directly. Tensors with any
// zero-sized dimension are a no-op, because a zero-sized grid is not a valid
// launch configuration.
template<typename T>
static void tri_cuda(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
        const float c, const ggml_tri_type ttype,
        cudaStream_t stream) {

    if (ne00 <= 0 || ne01 <= 0 || ne02 <= 0 || ne03 <= 0) {
        return;
    }

    // CUDA limits gridDim.x to 2^31-1 and gridDim.y / gridDim.z to 65535.
    const int64_t max_grid_x  = 2147483647;
    const int64_t max_grid_yz = 65535;

    const dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
    const dim3 grid_dims(
        (unsigned int) (ne01 < max_grid_x  ? ne01 : max_grid_x),
        (unsigned int) (ne02 < max_grid_yz ? ne02 : max_grid_yz),
        (unsigned int) (ne03 < max_grid_yz ? ne03 : max_grid_yz));

    tri_kernel<<<grid_dims, block_dims, 0, stream>>>(
        src, dst,
        ne00, ne01, ne02, ne03,
        nb00, nb01, nb02, nb03,
        nb0, nb1, nb2, nb3,
        c, ttype
    );
}

// Backend entry point for the TRI op.
//
// Reads the input tensor from dst->src[0], the triangle type from op param 0
// and the float constant from op param 1, then dispatches to tri_cuda on the
// context's stream according to the element type. Source and destination must
// have the same type; F32, F16 and BF16 are handled, anything else aborts.
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    cudaStream_t        stream = ctx.stream();

    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
    const float         c     = ggml_get_op_params_f32(dst, 1);

    GGML_ASSERT(src0->type == dst->type);

    // Shape comes from the source; byte strides are taken per tensor.
    const auto & ne     = src0->ne;
    const auto & src_nb = src0->nb;
    const auto & dst_nb = dst->nb;

    switch (src0->type) {
        case GGML_TYPE_F32:
            tri_cuda(
                static_cast<const float *>(src0->data), static_cast<float *>(dst->data),
                ne[0], ne[1], ne[2], ne[3],
                src_nb[0], src_nb[1], src_nb[2], src_nb[3],
                dst_nb[0], dst_nb[1], dst_nb[2], dst_nb[3],
                c, ttype, stream);
            break;
        case GGML_TYPE_F16:
            tri_cuda(
                static_cast<const half *>(src0->data), static_cast<half *>(dst->data),
                ne[0], ne[1], ne[2], ne[3],
                src_nb[0], src_nb[1], src_nb[2], src_nb[3],
                dst_nb[0], dst_nb[1], dst_nb[2], dst_nb[3],
                c, ttype, stream);
            break;
        case GGML_TYPE_BF16:
            tri_cuda(
                static_cast<const nv_bfloat16 *>(src0->data), static_cast<nv_bfloat16 *>(dst->data),
                ne[0], ne[1], ne[2], ne[3],
                src_nb[0], src_nb[1], src_nb[2], src_nb[3],
                dst_nb[0], dst_nb[1], dst_nb[2], dst_nb[3],
                c, ttype, stream);
            break;
        default:
            GGML_ABORT("fatal error");
    }
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/tri.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

// Number of threads per block (x dimension) used when launching the TRI kernel.
#define CUDA_TRI_BLOCK_SIZE 256

// CUDA implementation of the TRI op: reads the input from dst->src[0] and the
// triangle type / float constant from dst's op params, and writes the result
// into dst using the stream obtained from ctx.
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
Loading