Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
245f391
graph : reuse hybrid graphs
ggerganov Oct 9, 2025
638e2c2
graph : reuse recurrent graphs
ggerganov Oct 9, 2025
0b9c1ae
metal : fix mul-mm condition + fix mul-mv permuted kernels
ggerganov Oct 9, 2025
1f02d93
graph : fix reuse check for recurrent inputs
ggerganov Oct 10, 2025
00f115f
memory : move the recurrent state into the memory context
ggerganov Oct 10, 2025
2744d61
Revert "memory : move the recurrent state into the memory context"
ggerganov Oct 10, 2025
ab3f3fe
Merge branch 'gg/metal-mul-mat-fixes' into gg/graph-mamba-reuse
gabe-l-hart Oct 10, 2025
8c23c43
Added: tri, cumsum. Still a mess.
gabe-l-hart Oct 10, 2025
2a2e79c
feat(tests): Add --verbose | -v flag to test-backend-ops to print ten…
gabe-l-hart Oct 10, 2025
092f740
test: Add cumsum tests to test-backend-ops
gabe-l-hart Oct 10, 2025
6949ce7
feat(ggml-cpu): Add cumsum support for f16 and bf16
gabe-l-hart Oct 10, 2025
f8fba60
feat(ggml-cpu): Add F16 and BF16 support for tri
gabe-l-hart Oct 13, 2025
058160a
test: Add test cases for tri
gabe-l-hart Oct 13, 2025
86ce3da
chore: TODOs to loosen assertions in tri for ggml_is_contiguous
gabe-l-hart Oct 13, 2025
3a8958f
feat(ggml-metal): Initial (slow) implementation of cumsum for metal
gabe-l-hart Oct 13, 2025
cbaed86
feat(ggml-metal): Add stubs for metal tri
gabe-l-hart Oct 13, 2025
e596469
test: Use looser nmse for lower-precision types for cumsum
gabe-l-hart Oct 13, 2025
3011a6e
Merge remote-tracking branch 'origin/master' into Mamba2SSD
gabe-l-hart Oct 13, 2025
112d339
test: Allow multiple verbose flags to fully print tensors
gabe-l-hart Oct 15, 2025
78e137f
feat(llama-gguf): Print out the tensor type in llama-gguf
gabe-l-hart Sep 26, 2025
e5587cb
feat(ggml-metal): Efficient implementation of cumsum for metal
gabe-l-hart Oct 15, 2025
0468b99
test: More verbose printing and better cumsum tests
gabe-l-hart Oct 15, 2025
c71e35e
fix(ggml-metal): better granularity for support bool for CUMSUM and TRI
gabe-l-hart Oct 15, 2025
5f0d2a1
feat(ggml-metal): Metal impl of tri
gabe-l-hart Oct 15, 2025
426580d
Merge remote-tracking branch 'origin/master' into Mamba2SSD
gabe-l-hart Oct 15, 2025
ba3b8db
fix(ggml-cpu): Fix warnings from build with gcc
gabe-l-hart Oct 15, 2025
dfae909
feat(ggml-cuda): common implementation of prefix sum
gabe-l-hart Oct 16, 2025
d1f8658
feat(ggml-cuda): CUDA implementation of CUMSUM
gabe-l-hart Oct 16, 2025
5071fbd
feat(ggml-cuda): CUDA implementation of TRI
gabe-l-hart Oct 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat(ggml-cuda): CUDA implementation of CUMSUM
Branch: Mamba2SSD

Signed-off-by: Gabe Goodhart <[email protected]>
  • Loading branch information
gabe-l-hart committed Oct 16, 2025
commit d1f86582ed31649005c57ce886544571e2bb0f8e
126 changes: 126 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "cumsum.cuh"

// Kernel to compute the inclusive cumulative sum along the innermost dimension (ne[0]).
// Launch layout: grid = (ne01, ne02, ne03) — one 1-D block per row of ne00 elements.
// Dynamic shared memory requirement: ceil(ne00 / WARP_SIZE) floats, one partial sum
// per warp-sized chunk of the row (sized by the host-side launcher).
// Accumulation is always performed in float, even for f16/bf16 src/dst.
//
// Algorithm matches the Metal implementation:
// 1. Each warp computes a prefix sum within itself
// 2. Last thread of each warp stores its result in shared memory
// 3. All warps sync
// 4. Each element adds the sum of all preceding warps
//
// NOTE(review): nb00 and nb0 are accepted but never used — rows are indexed as
// src_row[i0] / dst_row[i0], i.e. the innermost dimension is assumed contiguous
// (nb00 == sizeof(T)); the caller must guarantee this.

template<typename T>
static __global__ void cumsum_kernel(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3) {

    // Shared memory to store per-warp-chunk sums (always float for accumulation)
    extern __shared__ float shmem[];

    // one block per (i1, i2, i3) row
    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;

    // defensive guard — the launcher sizes the grid exactly (ne01, ne02, ne03)
    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    // byte strides are applied for the outer dims only (see NOTE above about nb00/nb0)
    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
    T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);

    const int tid = threadIdx.x;
    const int lane_id = tid % WARP_SIZE;

    // Phase 1: each thread walks the row at stride blockDim.x. Because blockDim.x is a
    // multiple of WARP_SIZE, every warp covers the same 32-element-aligned chunk on each
    // iteration, so warp_prefix_inclusive_sum yields the inclusive prefix of that chunk.
    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
        // load in float and compute the prefix sum within the warp
        float val = static_cast<float>(src_row[i0]);
        val = warp_prefix_inclusive_sum(val);
        dst_row[i0] = static_cast<T>(val);

        // The last lane of the chunk (or the last element of the row) publishes the
        // chunk total; indexing by data position (i0 / WARP_SIZE) keeps chunks from
        // later loop iterations after earlier ones.
        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
            const int shmem_idx = i0 / WARP_SIZE;
            shmem[shmem_idx] = val;
        }
        // NOTE(review): when ne00 is not a multiple of WARP_SIZE, the tail chunk runs
        // this body with a partial warp — confirm warp_prefix_inclusive_sum (common.cuh)
        // tolerates inactive lanes (e.g. derives its mask from a ballot) rather than
        // assuming all 32 lanes participate.
    }

    // Single barrier: all chunk totals are in shared memory before phase 2 reads them.
    // Reached unconditionally by every thread of the block (no divergence around it).
    __syncthreads();

    // Phase 2: add the totals of all preceding chunks to each element.
    // The inner loop is O(ne00 / WARP_SIZE) per element — fine for short rows,
    // quadratic in the number of chunks for very long ones.
    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
        const int shmem_idx = i0 / WARP_SIZE;
        float sum = 0.0f;
        for (int j = 0; j < shmem_idx; ++j) {
            sum += shmem[j];
        }
        dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
    }
}

// Host-side launcher for cumsum_kernel: one 1-D block of CUDA_CUMSUM_BLOCK_SIZE
// threads per row, grid spanning (ne01, ne02, ne03). Dynamic shared memory holds
// one float per warp-sized chunk of the ne00 row elements.
template<typename T>
static void cumsum_cuda(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
        cudaStream_t stream) {

    // one float of scratch per warp-sized chunk of the row
    const int64_t n_chunks   = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
    const size_t  shmem_size = n_chunks*sizeof(float);

    const dim3 grid_dims (ne01, ne02, ne03);
    const dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1);

    cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
            src, dst,
            ne00, ne01, ne02, ne03,
            nb00, nb01, nb02, nb03,
            nb0,  nb1,  nb2,  nb3);
}

// Forward op: dst = cumsum(dst->src[0]) along the innermost dimension.
// Dispatches on element type (F32 / F16 / BF16); src0 and dst must share a type.
// The kernel indexes rows with plain T* arithmetic and ignores nb[0], so the
// innermost dimension of both tensors must be contiguous — assert that up front
// instead of silently producing wrong results for permuted/strided views.
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0)); // kernel assumes contiguous ne[0]
    GGML_ASSERT( dst->nb[0] == ggml_element_size(dst));  // same for the destination

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                cumsum_cuda(
                    (const float *) src0->data, (float *) dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        case GGML_TYPE_F16:
            {
                cumsum_cuda(
                    (const half *) src0->data, (half *) dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        case GGML_TYPE_BF16:
            {
                cumsum_cuda(
                    (const nv_bfloat16 *) src0->data, (nv_bfloat16 *) dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        default:
            GGML_ABORT("fatal error");
    }
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

// threads per block for the cumsum kernel (1-D blocks, one block per row)
#define CUDA_CUMSUM_BLOCK_SIZE 256

// dst = cumulative sum of dst->src[0] along the innermost dimension (F32/F16/BF16)
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/fattn.cuh"
#include "ggml-cuda/getrows.cuh"
Expand Down Expand Up @@ -2512,6 +2513,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
Expand Down Expand Up @@ -3650,6 +3654,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_CUMSUM:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true;
Expand Down