Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
245f391
graph : reuse hybrid graphs
ggerganov Oct 9, 2025
638e2c2
graph : reuse recurrent graphs
ggerganov Oct 9, 2025
0b9c1ae
metal : fix mul-mm condition + fix mul-mv permuted kernels
ggerganov Oct 9, 2025
1f02d93
graph : fix reuse check for recurrent inputs
ggerganov Oct 10, 2025
00f115f
memory : move the recurrent state into the memory context
ggerganov Oct 10, 2025
2744d61
Revert "memory : move the recurrent state into the memory context"
ggerganov Oct 10, 2025
ab3f3fe
Merge branch 'gg/metal-mul-mat-fixes' into gg/graph-mamba-reuse
gabe-l-hart Oct 10, 2025
8c23c43
Added: tri, cumsum. Still a mess.
gabe-l-hart Oct 10, 2025
2a2e79c
feat(tests): Add --verbose | -v flag to test-backend-ops to print ten…
gabe-l-hart Oct 10, 2025
092f740
test: Add cumsum tests to test-backend-ops
gabe-l-hart Oct 10, 2025
6949ce7
feat(ggml-cpu): Add cumsum support for f16 and bf16
gabe-l-hart Oct 10, 2025
f8fba60
feat(ggml-cpu): Add F16 and BF16 support for tri
gabe-l-hart Oct 13, 2025
058160a
test: Add test cases for tri
gabe-l-hart Oct 13, 2025
86ce3da
chore: TODOs to loosen assertions in tri for ggml_is_contiguous
gabe-l-hart Oct 13, 2025
3a8958f
feat(ggml-metal): Initial (slow) implementation of cumsum for metal
gabe-l-hart Oct 13, 2025
cbaed86
feat(ggml-metal): Add stubs for metal tri
gabe-l-hart Oct 13, 2025
e596469
test: Use looser nmse for lower-precision types for cumsum
gabe-l-hart Oct 13, 2025
3011a6e
Merge remote-tracking branch 'origin/master' into Mamba2SSD
gabe-l-hart Oct 13, 2025
112d339
test: Allow multiple verbose flags to fully print tensors
gabe-l-hart Oct 15, 2025
78e137f
feat(llama-gguf): Print out the tensor type in llama-gguf
gabe-l-hart Sep 26, 2025
e5587cb
feat(ggml-metal): Efficient implementation of cumsum for metal
gabe-l-hart Oct 15, 2025
0468b99
test: More verbose printing and better cumsum tests
gabe-l-hart Oct 15, 2025
c71e35e
fix(ggml-metal): better granularity for support bool for CUMSUM and TRI
gabe-l-hart Oct 15, 2025
5f0d2a1
feat(ggml-metal): Metal impl of tri
gabe-l-hart Oct 15, 2025
426580d
Merge remote-tracking branch 'origin/master' into Mamba2SSD
gabe-l-hart Oct 15, 2025
ba3b8db
fix(ggml-cpu): Fix warnings from build with gcc
gabe-l-hart Oct 15, 2025
dfae909
feat(ggml-cuda): common implementation of prefix sum
gabe-l-hart Oct 16, 2025
d1f8658
feat(ggml-cuda): CUDA implementation of CUMSUM
gabe-l-hart Oct 16, 2025
5071fbd
feat(ggml-cuda): CUDA implementation of TRI
gabe-l-hart Oct 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
graph : reuse hybrid graphs
  • Loading branch information
ggerganov committed Oct 9, 2025
commit 245f39157611918962ded35a425fb7501f898f9a
41 changes: 38 additions & 3 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,43 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}

void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
inp_attn->set_input(ubatch);
inp_rs->set_input(ubatch);
mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
mctx->get_attn()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);

mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);

const int64_t n_rs = mctx->get_recr()->get_n_rs();

if (inp_rs->s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
int32_t * data = (int32_t *) inp_rs->s_copy->data;

// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mctx->get_recr()->s_copy(i);
}
}
}

// Decide whether a previously built hybrid-memory graph can be reused for the
// new graph parameters: all input tensor shapes must still match what the new
// ubatch and memory context require.
//
// Note: the member mctx is rebound to the new context unconditionally, even
// when reuse is rejected — callers rebuild the graph in that case.
bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(params.mctx);

    this->mctx = mctx_cur;

    const auto & ubatch = params.ubatch;

    // attention inputs: token index count and KQ mask dimensions
    const bool attn_ok =
        inp_attn->self_k_idxs->ne[0] == ubatch.n_tokens &&
        //inp_attn->self_v_idxs->ne[0] == ubatch.n_tokens && // TODO: need to move this to the unified cache and check there
        inp_attn->self_kq_mask->ne[0] == mctx_cur->get_attn()->get_n_kv() &&
        inp_attn->self_kq_mask->ne[1] == GGML_PAD(ubatch.n_tokens, GGML_KQ_MASK_PAD);

    // recurrent-state inputs: copy tensor and its main/extra partitions
    const bool rs_ok =
        inp_rs->s_copy->ne[0]       == mctx_cur->get_recr()->get_n_rs() &&
        inp_rs->s_copy_main->ne[0]  == ubatch.n_seqs &&
        inp_rs->s_copy_extra->ne[0] == mctx_cur->get_recr()->get_n_rs() - ubatch.n_seqs;

    return attn_ok && rs_ok;
}

//
Expand Down Expand Up @@ -1848,7 +1883,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());

auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);

return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
Expand Down
10 changes: 8 additions & 2 deletions src/llama-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,22 +360,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
inp_attn(std::move(inp_attn)),
inp_rs(std::move(inp_rs)),
cparams(cparams),
mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid() = default;

void set_input(const llama_ubatch * ubatch) override;

bool can_reuse(const llm_graph_params & params) override;

std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;

llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }

const llama_cparams cparams;

const llama_memory_hybrid_context * mctx;
};

Expand Down
2 changes: 1 addition & 1 deletion src/llama-memory-hybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}

Expand Down