Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
245f391
graph : reuse hybrid graphs
ggerganov Oct 9, 2025
638e2c2
graph : reuse recurrent graphs
ggerganov Oct 9, 2025
0b9c1ae
metal : fix mul-mm condition + fix mul-mv permuted kernels
ggerganov Oct 9, 2025
1f02d93
graph : fix reuse check for recurrent inputs
ggerganov Oct 10, 2025
00f115f
memory : move the recurrent state into the memory context
ggerganov Oct 10, 2025
2744d61
Revert "memory : move the recurrent state into the memory context"
ggerganov Oct 10, 2025
ab3f3fe
Merge branch 'gg/metal-mul-mat-fixes' into gg/graph-mamba-reuse
gabe-l-hart Oct 10, 2025
8c23c43
Added: tri, cumsum. Still a mess.
gabe-l-hart Oct 10, 2025
2a2e79c
feat(tests): Add --verbose | -v flag to test-backend-ops to print ten…
gabe-l-hart Oct 10, 2025
092f740
test: Add cumsum tests to test-backend-ops
gabe-l-hart Oct 10, 2025
6949ce7
feat(ggml-cpu): Add cumsum support for f16 and bf16
gabe-l-hart Oct 10, 2025
f8fba60
feat(ggml-cpu): Add F16 and BF16 support for tri
gabe-l-hart Oct 13, 2025
058160a
test: Add test cases for tri
gabe-l-hart Oct 13, 2025
86ce3da
chore: TODOs to loosen assertions in tri for ggml_is_contiguous
gabe-l-hart Oct 13, 2025
3a8958f
feat(ggml-metal): Initial (slow) implementation of cumsum for metal
gabe-l-hart Oct 13, 2025
cbaed86
feat(ggml-metal): Add stubs for metal tri
gabe-l-hart Oct 13, 2025
e596469
test: Use looser nmse for lower-precision types for cumsum
gabe-l-hart Oct 13, 2025
3011a6e
Merge remote-tracking branch 'origin/master' into Mamba2SSD
gabe-l-hart Oct 13, 2025
112d339
test: Allow multiple verbose flags to fully print tensors
gabe-l-hart Oct 15, 2025
78e137f
feat(llama-gguf): Print out the tensor type in llama-gguf r
gabe-l-hart Sep 26, 2025
e5587cb
feat(ggml-metal): Efficient implementation of cumsum for metal
gabe-l-hart Oct 15, 2025
0468b99
test: More verbose printing and better cumsum tests
gabe-l-hart Oct 15, 2025
c71e35e
fix(ggml-metal): better granularity for support bool for CUMSUM and TRI
gabe-l-hart Oct 15, 2025
5f0d2a1
feat(ggml-metal): Metal impl of tri
gabe-l-hart Oct 15, 2025
426580d
Merge remote-tracking branch 'origin/master' into Mamba2SSD
gabe-l-hart Oct 15, 2025
ba3b8db
fix(ggml-cpu): Fix warnings from build with gcc
gabe-l-hart Oct 15, 2025
dfae909
feat(ggml-cuda): common implementation of prefix sum
gabe-l-hart Oct 16, 2025
d1f8658
feat(ggml-cuda): CUDA implementation of CUMSUM
gabe-l-hart Oct 16, 2025
5071fbd
feat(ggml-cuda): CUDA implementation of TRI
gabe-l-hart Oct 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added: tri, cumsum. Still a mess.
Cherry-picked and edited from 7ec2df6

The original commit contained the DELTA_NET op as well which I've removed
in this cherry-picked version.

Co-Authored-By: Piotr Wilkin <[email protected]>

Signed-off-by: Gabe Goodhart <[email protected]>
  • Loading branch information
gabe-l-hart committed Oct 10, 2025
commit 8c23c43588ed38b8fdfcfb8cf8b90f77c3570ea7
24 changes: 24 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ extern "C" {
GGML_OP_COS,
GGML_OP_SUM,
GGML_OP_SUM_ROWS,
GGML_OP_CUMSUM,
GGML_OP_MEAN,
GGML_OP_ARGMAX,
GGML_OP_COUNT_EQUAL,
Expand Down Expand Up @@ -529,6 +530,7 @@ extern "C" {
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,
GGML_OP_TRI,

GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
Expand Down Expand Up @@ -615,6 +617,13 @@ extern "C" {
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
};

enum ggml_tri_type {
GGML_TRI_TYPE_UPPER_DIAG = 0,
GGML_TRI_TYPE_UPPER = 1,
GGML_TRI_TYPE_LOWER_DIAG = 2,
GGML_TRI_TYPE_LOWER = 3
};

struct ggml_init_params {
// memory pool
size_t mem_size; // bytes
Expand Down Expand Up @@ -978,6 +987,10 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_cumsum(
struct ggml_context * ctx,
struct ggml_tensor * a);

// mean along rows
GGML_API struct ggml_tensor * ggml_mean(
struct ggml_context * ctx,
Expand Down Expand Up @@ -2141,6 +2154,17 @@ extern "C" {
int shift2,
int shift3);

// Make matrix into a triangular one (upper, upper + diagonal, lower or lower + diagonal) with constant value
GGML_API struct ggml_tensor * ggml_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
float constant,
enum ggml_tri_type tritype);

GGML_API struct ggml_tensor * ggml_tri_keep(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_tri_type tritype);

// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
Expand Down
10 changes: 10 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_sum_rows(params, tensor);
} break;
case GGML_OP_CUMSUM:
{
ggml_compute_forward_cumsum(params, tensor);
} break;
case GGML_OP_MEAN:
{
ggml_compute_forward_mean(params, tensor);
Expand Down Expand Up @@ -1943,6 +1947,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_leaky_relu(params, tensor);
} break;
case GGML_OP_TRI:
{
ggml_compute_forward_tri(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_compute_forward_flash_attn_ext(params, tensor);
Expand Down Expand Up @@ -2153,6 +2161,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGMAX:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
{
n_tasks = 1;
} break;
Expand Down
95 changes: 95 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <float.h>
#include <algorithm>
#include <cmath>

// ggml_compute_forward_dup

Expand Down Expand Up @@ -1394,6 +1395,57 @@ void ggml_compute_forward_sum(
}
}

// ggml_compute_forward_cumsum

static void ggml_compute_forward_cumsum_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

if (params->ith != 0) {
return;
}

GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(dst->nb[0] == sizeof(float));

GGML_TENSOR_UNARY_OP_LOCALS

GGML_ASSERT(ne0 == ne00);
GGML_ASSERT(ne1 == ne01);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);

for (int64_t i3 = 0; i3 < ne03; i3++) {
for (int64_t i2 = 0; i2 < ne02; i2++) {
for (int64_t i1 = 0; i1 < ne01; i1++) {
float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
ggml_vec_cumsum_f32(ne00, dst_row, src_row);
}
}
}
}

void ggml_compute_forward_cumsum(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_cumsum_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_sum_rows

static void ggml_compute_forward_sum_rows_f32(
Expand Down Expand Up @@ -2140,6 +2192,49 @@ static void ggml_compute_forward_gelu(
}
}

// ggml_compute_tri

static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];

ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0];
float c = *((float *) &(dst->op_params[1]));
bool keep_org_val = isnan(c);

GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->ne[0] == src0->ne[1]);

GGML_TENSOR_UNARY_OP_LOCALS

const auto [ir0, ir1] = get_thread_range(params, src0);

for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 );
ggml_vec_tri_f32(ne0, i01, dst_ptr, src, keep_org_val, c, ttype);
}

}

void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_tri_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_gelu_erf

static void ggml_compute_forward_gelu_erf_f32(
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-cpu/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cumsum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst);
Expand Down Expand Up @@ -85,6 +86,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
Expand Down
32 changes: 32 additions & 0 deletions ggml/src/ggml-cpu/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,38 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#endif
}

// Applies a triangular mask to the input vector 'src' and writes the result to 'dst'.
// Parameters:
// n - number of elements
// r - current row index
// dst - output array
// src - input array
// keep_org_val - if true, keep original value where mask applies; otherwise use constant 'c'
// c - constant value to use when not keeping original value
// type - type of triangular mask (lower, upper, etc.)
inline static void ggml_vec_tri_f32(const int n, const int r, float * dst, const float * src, bool keep_org_val, float c, enum ggml_tri_type type) {
for (int i = 0; i < n; ++i) {
bool cmp;
switch (type) {
case GGML_TRI_TYPE_LOWER: cmp = i < r; break;
case GGML_TRI_TYPE_LOWER_DIAG: cmp = i <= r; break;
case GGML_TRI_TYPE_UPPER: cmp = i > r; break;
case GGML_TRI_TYPE_UPPER_DIAG: cmp = i >= r; break;
}
dst[i] = cmp ? (keep_org_val ? src[i] : c) : 0.0f;
}
}

inline static void ggml_vec_cumsum_f32(const int n, float * y, const float * x) {
for (int i = 0; i < n; ++i) {
if (i == 0) {
y[i] = x[i];
} else {
y[i] = y[i - 1] + x[i];
}
}
}

inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
ggml_float sum = 0.0;
for (int i = 0; i < n; ++i) {
Expand Down
49 changes: 47 additions & 2 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"COS",
"SUM",
"SUM_ROWS",
"CUMSUM",
"MEAN",
"ARGMAX",
"COUNT_EQUAL",
Expand Down Expand Up @@ -990,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TIMESTEP_EMBEDDING",
"ARGSORT",
"LEAKY_RELU",
"TRI",

"FLASH_ATTN_EXT",
"FLASH_ATTN_BACK",
Expand Down Expand Up @@ -1019,7 +1021,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};

static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
Expand All @@ -1039,6 +1041,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cos(x)",
"Σx",
"Σx_k",
"cumsum(x)",
"Σx/n",
"argmax(x)",
"count_equal(x)",
Expand Down Expand Up @@ -1094,6 +1097,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"leaky_relu(x)",
"tri(x)",

"flash_attn_ext(x)",
"flash_attn_back(x)",
Expand Down Expand Up @@ -1123,7 +1127,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};

static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_COUNT == 92, "GGML_OP_COUNT != 92");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

Expand Down Expand Up @@ -2337,6 +2341,20 @@ struct ggml_tensor * ggml_sum_rows(
return result;
}

// ggml_cumsum

struct ggml_tensor * ggml_cumsum(
struct ggml_context * ctx,
struct ggml_tensor * a) {

struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, a->ne);

result->op = GGML_OP_CUMSUM;
result->src[0] = a;

return result;
}

// ggml_mean

struct ggml_tensor * ggml_mean(
Expand Down Expand Up @@ -4968,6 +4986,33 @@ struct ggml_tensor * ggml_timestep_embedding(
return result;
}

// ggml_tri

struct ggml_tensor * ggml_tri(
struct ggml_context * ctx,
struct ggml_tensor * a,
float constant,
enum ggml_tri_type tritype) {

struct ggml_tensor * result = ggml_dup_tensor(ctx, a);

ggml_set_op_params_i32(result, 0, tritype);
ggml_set_op_params_f32(result, 1, constant);

result->op = GGML_OP_TRI;
result->src[0] = a;

return result;
}

struct ggml_tensor * ggml_tri_keep(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_tri_type tritype) {

return ggml_tri(ctx, a, nan(""), tritype);
}

// ggml_argsort

struct ggml_tensor * ggml_argsort(
Expand Down