Skip to content
Previous commit
Next commit
  • Loading branch information
cacaview committed Nov 30, 2025
commit 780dd783ac5744970586fa8acd7832a3ac3fc19c
2 changes: 1 addition & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5121,7 +5121,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
@ModelBase.register("KimiLinearModel", "KimiLinearForCausalLM")
class KimiLinearModel(TextModel):
"""Kimi-Linear model with hybrid MLA+KDA architecture"""
model_arch = gguf.MODEL_ARCH.KIMI
model_arch = gguf.MODEL_ARCH.KIMI_LINEAR

_experts: list[dict[str, Tensor]] | None = None

Expand Down
6 changes: 3 additions & 3 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ class MODEL_ARCH(IntEnum):
MINIMAXM2 = auto()
RND1 = auto()
PANGU_EMBED = auto()
KIMI = auto() # Kimi-Linear (hybrid MLA+KDA)
KIMI_LINEAR = auto() # Kimi-Linear (hybrid MLA+KDA)


class VISION_PROJECTOR_TYPE(IntEnum):
Expand Down Expand Up @@ -830,7 +830,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.COGVLM: "cogvlm",
MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.KIMI: "kimi",
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
Expand Down Expand Up @@ -3096,7 +3096,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.KIMI: [
MODEL_ARCH.KIMI_LINEAR: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
Expand Down
36 changes: 27 additions & 9 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,15 +1570,33 @@ class TensorNameMap:
),

# Kimi Linear KDA (using SSM_ prefix for consistency)
MODEL_TENSOR.SSM_CONV1D_Q: ("model.layers.{bid}.self_attn.q_conv1d",),
MODEL_TENSOR.SSM_CONV1D_K: ("model.layers.{bid}.self_attn.k_conv1d",),
MODEL_TENSOR.SSM_CONV1D_V: ("model.layers.{bid}.self_attn.v_conv1d",),
MODEL_TENSOR.SSM_F_A: ("model.layers.{bid}.self_attn.f_a_proj",),
MODEL_TENSOR.SSM_F_B: ("model.layers.{bid}.self_attn.f_b_proj",),
MODEL_TENSOR.SSM_BETA: ("model.layers.{bid}.self_attn.b_proj",),
MODEL_TENSOR.SSM_A_LOG: ("model.layers.{bid}.self_attn.A_log",),
MODEL_TENSOR.SSM_G_A: ("model.layers.{bid}.self_attn.g_a_proj",),
MODEL_TENSOR.SSM_G_B: ("model.layers.{bid}.self_attn.g_b_proj",),
MODEL_TENSOR.SSM_CONV1D_Q: (
"model.layers.{bid}.self_attn.q_conv1d",
),
MODEL_TENSOR.SSM_CONV1D_K: (
"model.layers.{bid}.self_attn.k_conv1d",
),
MODEL_TENSOR.SSM_CONV1D_V: (
"model.layers.{bid}.self_attn.v_conv1d",
),
MODEL_TENSOR.SSM_F_A: (
"model.layers.{bid}.self_attn.f_a_proj",
),
MODEL_TENSOR.SSM_F_B: (
"model.layers.{bid}.self_attn.f_b_proj",
),
MODEL_TENSOR.SSM_BETA: (
"model.layers.{bid}.self_attn.b_proj",
),
MODEL_TENSOR.SSM_A_LOG: (
"model.layers.{bid}.self_attn.A_log",
),
MODEL_TENSOR.SSM_G_A: (
"model.layers.{bid}.self_attn.g_a_proj",
),
MODEL_TENSOR.SSM_G_B: (
"model.layers.{bid}.self_attn.g_b_proj",
),
MODEL_TENSOR.SSM_DT_B: (
"model.layers.{bid}.self_attn.dt_bias",
),
Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ add_library(llama
models/internlm2.cpp
models/jais.cpp
models/jamba.cpp
models/kimi.cpp
models/kimi-linear.cpp
models/lfm2.cpp
models/llada-moe.cpp
models/llada.cpp
Expand Down
8 changes: 4 additions & 4 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_COGVLM, "cogvlm" },
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_KIMI, "kimi" },
{ LLM_ARCH_KIMI_LINEAR, "kimi-linear" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -2494,7 +2494,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
},
},
{
LLM_ARCH_KIMI,
LLM_ARCH_KIMI_LINEAR,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
Expand Down Expand Up @@ -2833,7 +2833,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
case LLM_ARCH_KIMI: // KDA layers use delta attention with recurrent state
case LLM_ARCH_KIMI_LINEAR: // KDA layers use delta attention with recurrent state
return true;
default:
return false;
Expand All @@ -2852,7 +2852,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_QWEN3NEXT:
// Kimi: Currently using recurrent-only mode since MLA doesn't use KV cache
// TODO: Enable hybrid when MLA KV caching is implemented
// case LLM_ARCH_KIMI:
// case LLM_ARCH_KIMI_LINEAR:
return true;
default:
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ enum llm_arch {
LLM_ARCH_COGVLM,
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_KIMI,
LLM_ARCH_KIMI_LINEAR,
LLM_ARCH_UNKNOWN,
};

Expand Down
2 changes: 1 addition & 1 deletion src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1387,7 +1387,7 @@ void llama_context::output_reorder() {
//

uint32_t llama_context::graph_max_nodes() const {
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI) {
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) {
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
}
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
Expand Down
10 changes: 5 additions & 5 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2247,7 +2247,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_KIMI:
case LLM_ARCH_KIMI_LINEAR:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false);
Expand Down Expand Up @@ -6406,7 +6406,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0);
}
} break;
case LLM_ARCH_KIMI:
case LLM_ARCH_KIMI_LINEAR:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

Expand Down Expand Up @@ -7712,9 +7712,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_qwen3next>(*this, params);
} break;
case LLM_ARCH_KIMI:
case LLM_ARCH_KIMI_LINEAR:
{
llm = std::make_unique<llm_build_kimi>(*this, params);
llm = std::make_unique<llm_build_kimi_linear>(*this, params);
} break;
default:
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -7871,7 +7871,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ARCTIC:
case LLM_ARCH_DEEPSEEK:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_KIMI:
case LLM_ARCH_KIMI_LINEAR:
case LLM_ARCH_PLM:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GLM4:
Expand Down
2 changes: 1 addition & 1 deletion src/llama-quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// sanity checks for models that have attention layers
// Skip this check for Kimi models which have hybrid KDA+MLA architecture
// (only MLA layers have attn_kv_b weights, KDA layers don't)
if (qs.n_attention_wv != 0 && !is_clip_model && model.arch != LLM_ARCH_KIMI)
if (qs.n_attention_wv != 0 && !is_clip_model && model.arch != LLM_ARCH_KIMI_LINEAR)
{
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
// attention layers have a non-zero number of kv heads
Expand Down
2 changes: 1 addition & 1 deletion src/models/kimi.cpp → src/models/kimi-linear.cpp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename to kimi-linear.cpp

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "models.h"

llm_build_kimi::llm_build_kimi(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params), model(model) {
llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params), model(model) {
ggml_tensor * cur;
ggml_tensor * inpL;

Expand Down
4 changes: 2 additions & 2 deletions src/models/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ struct llm_build_jamba : public llm_graph_context_mamba {
llm_build_jamba(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_kimi : public llm_graph_context_mamba {
llm_build_kimi(const llama_model & model, const llm_graph_params & params);
struct llm_build_kimi_linear : public llm_graph_context_mamba {
llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
private:
const llama_model & model;
};
Expand Down