feat: complete Kimi-Linear inference implementation
- Implement KDA layer (linear attention with gates and decay)
- Implement MLA layer (multi-head latent attention with KV compression)
- Support MoE FFN with shared experts
- Add TikToken tokenizer support for Kimi models
- Fix vocab loading for large vocabularies
- Model loads and runs inference (27 layers, 603 tensors)
cacaview committed Nov 28, 2025
commit 0e047846400b176fab231e58671d7dccc2711396
260 changes: 157 additions & 103 deletions convert_hf_to_gguf.py
@@ -2722,58 +2722,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [] # skip other tensors


@ModelBase.register("KimiLinearForCausalLM")
class KimiLinearModel(ModelBase):
model_arch = gguf.MODEL_ARCH.KIMI

def set_gguf_parameters(self):
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["qk_rope_head_dim"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])

linear_attn = self.hparams.get("linear_attn_config", {})
if linear_attn:
self.gguf_writer.add_ssm_conv_kernel(linear_attn.get("short_conv_kernel_size", 4))
# Add other Kimi params as generic KV if needed or extend GGUFWriter
# For now we rely on conv_kernel being enough for the conv op

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "gate_up_proj" in name:
# shape: (2 * intermediate_size, hidden_size)
# split along dim 0. Assuming [gate; up]
out_dim = data_torch.shape[0]
mid = out_dim // 2
w1 = data_torch[:mid, :] # gate
w3 = data_torch[mid:, :] # up

# Map directly using the split names which should map to FFN_GATE and FFN_UP
# We need to construct the original names that map_tensor_name expects for mapping
# Or we could map manually if we know the logic.
# But modify_tensors usually returns mapped names.

# tensor_mapping.py:
# FFN_GATE: "mlp.gate_proj" (standard llama)
# FFN_UP: "mlp.up_proj"

# name is something like "model.layers.0.mlp.gate_up_proj.weight"
name_gate = name.replace("gate_up_proj", "gate_proj")
name_up = name.replace("gate_up_proj", "up_proj")

return [
(self.map_tensor_name(name_gate), w1),
(self.map_tensor_name(name_up), w3)
]

# Handle 1x1xHx1 tensors like A_log
if "A_log" in name:
data_torch = data_torch.squeeze()

return [(self.map_tensor_name(name), data_torch)]
# KimiLinearModel is defined later in this file (line ~5140) as a TextModel subclass
# This old definition has been removed to avoid conflicts
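For reference, the gate_up_proj split in the removed code above can be checked in isolation; a minimal sketch with toy sizes, assuming the fused weight stacks [gate; up] along dim 0:

import torch

# Hedged sketch (editorial, not part of the commit): toy sizes, hypothetical names.
hidden_size = 8
intermediate_size = 16

fused = torch.randn(2 * intermediate_size, hidden_size)  # mlp.gate_up_proj.weight
mid = fused.shape[0] // 2
gate = fused[:mid, :]   # -> mlp.gate_proj.weight (FFN_GATE)
up   = fused[mid:, :]   # -> mlp.up_proj.weight   (FFN_UP)

assert gate.shape == up.shape == (intermediate_size, hidden_size)
assert torch.equal(torch.cat([gate, up], dim=0), fused)  # the split is lossless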


@ModelBase.register(
@@ -5162,12 +5112,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), k),
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v),
]
Collaborator suggested change:
]
]
else:
return [(self.map_tensor_name(name), data_torch)]

else:
return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("KimiLinearModel", "KimiLinearForCausalLM")
class KimiLinearModel(TextModel):
"""Kimi-Linear model with hybrid MLA+KDA architecture"""
model_arch = gguf.MODEL_ARCH.KIMI
Collaborator suggested change:
model_arch = gguf.MODEL_ARCH.KIMI
model_arch = gguf.MODEL_ARCH.KIMI_LINEAR


_experts: list[dict[str, Tensor]] | None = None
@@ -5191,17 +5140,36 @@ def set_gguf_parameters(self):
# KDA & MLA params
# Assuming these keys exist in config.json for Kimi models
if "ssm_d_conv" in self.hparams:
self.gguf_writer.add_uint32(gguf.KEY_SSM_CONV_KERNEL, self.hparams["ssm_d_conv"])
self.gguf_writer.add_ssm_conv_kernel(self.hparams["ssm_d_conv"])

# MLA params - use add_* methods that handle arch substitution
# Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv)
q_lora_rank = self.hparams.get("q_lora_rank", self.hparams.get("n_lora_q"))
kv_lora_rank = self.hparams.get("kv_lora_rank", self.hparams.get("n_lora_kv"))

# MLA params
if "n_lora_q" in self.hparams:
self.gguf_writer.add_uint32(gguf.KEY_ATTENTION_Q_LORA_RANK, self.hparams["n_lora_q"])
if "n_lora_kv" in self.hparams:
self.gguf_writer.add_uint32(gguf.KEY_ATTENTION_KV_LORA_RANK, self.hparams["n_lora_kv"])
if q_lora_rank is not None:
self.gguf_writer.add_q_lora_rank(q_lora_rank)
if kv_lora_rank is not None:
self.gguf_writer.add_kv_lora_rank(kv_lora_rank)

# MLA head dimensions
# Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim
qk_nope_head_dim = self.hparams.get("qk_nope_head_dim")
qk_rope_head_dim = self.hparams.get("qk_rope_head_dim", self.hparams.get("n_rot"))
v_head_dim = self.hparams.get("v_head_dim")

# Calculate n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
if "n_embd_head_k_mla" in self.hparams:
self.gguf_writer.add_uint32(gguf.KEY_ATTENTION_KEY_LENGTH_MLA, self.hparams["n_embd_head_k_mla"])
self.gguf_writer.add_key_length_mla(self.hparams["n_embd_head_k_mla"])
elif qk_nope_head_dim is not None and qk_rope_head_dim is not None:
n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
self.gguf_writer.add_key_length_mla(n_embd_head_k_mla)

# n_embd_head_v_mla = v_head_dim
if "n_embd_head_v_mla" in self.hparams:
self.gguf_writer.add_uint32(gguf.KEY_ATTENTION_VALUE_LENGTH_MLA, self.hparams["n_embd_head_v_mla"])
self.gguf_writer.add_value_length_mla(self.hparams["n_embd_head_v_mla"])
elif v_head_dim is not None:
self.gguf_writer.add_value_length_mla(v_head_dim)
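To make the fallback arithmetic above concrete, a minimal sketch with hypothetical hparams values (placeholders, not taken from a real Kimi config.json):

# Hedged sketch: hypothetical values, only to illustrate the fallback path above.
hparams = {
    "q_lora_rank": 1536,
    "kv_lora_rank": 512,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 128,
}

q_lora_rank = hparams.get("q_lora_rank", hparams.get("n_lora_q"))     # 1536
kv_lora_rank = hparams.get("kv_lora_rank", hparams.get("n_lora_kv"))  # 512
n_embd_head_k_mla = hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]  # 192
n_embd_head_v_mla = hparams["v_head_dim"]                                      # 128
# written via gguf_writer.add_q_lora_rank / add_kv_lora_rank /
# add_key_length_mla / add_value_length_mla as above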

# Rotation
# Kimi likely uses n_rot
@@ -5221,6 +5189,94 @@ def set_gguf_parameters(self):
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)

def set_vocab(self):
# Kimi uses TikToken tokenizer - load via transformers
from transformers import AutoTokenizer

dir_model = self.dir_model
vocab_size = self.hparams["vocab_size"]

logger.info(f"Loading TikToken tokenizer from {dir_model}")
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)

tokens: list[str] = []
toktypes: list[int] = []

# Get tokenizer pre string
tokpre = self.get_vocab_base_pre(tokenizer)

# Build vocab from tokenizer
merges = []
vocab = {}

# TikToken stores vocab in mergeable_ranks
if hasattr(tokenizer, 'mergeable_ranks'):
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[self._token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
# Build merges
merged = self._bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2:
merges.append(' '.join(map(self._token_bytes_to_string, merged)))
else:
# Fallback: get vocab directly
vocab = {tok: idx for tok, idx in tokenizer.get_vocab().items()}

# Get special tokens
added_vocab = {}
if hasattr(tokenizer, 'special_tokens'):
added_vocab = tokenizer.special_tokens
elif hasattr(tokenizer, 'added_tokens_encoder'):
added_vocab = tokenizer.added_tokens_encoder

# Combine vocab
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}

for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
elif added_vocab and i in added_vocab.values():
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.CONTROL)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
special_vocab.merges = merges
special_vocab.add_to_gguf(self.gguf_writer)
logger.info(f"Loaded {len(tokens)} tokens, {len(merges)} merges")

@staticmethod
def _token_bytes_to_string(b: bytes) -> str:
"""Convert bytes to string representation for tokenizer"""
return ''.join([chr(byte) if byte < 128 else f'<0x{byte:02X}>' for byte in b])

@staticmethod
def _bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
"""BPE tokenization for merges extraction"""
parts = [bytes([b]) for b in token]
while True:
min_idx = None
min_rank = None
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
rank = mergeable_ranks.get(pair[0] + pair[1])
if rank is not None and (min_rank is None or rank < min_rank):
min_idx = i
min_rank = rank
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
break
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
return parts
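A small worked example of how the two helpers above turn a TikToken-style rank table into BPE merges; the toy mergeable_ranks here is hypothetical:

# Hedged sketch: toy rank table, not Kimi's real vocabulary.
mergeable_ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}

# For token b"abc" (rank 4), _bpe replays BPE using only merges of lower rank,
# so it stops one step early and returns the pair that produced the token:
parts = KimiLinearModel._bpe(mergeable_ranks, b"abc", max_rank=4)
assert parts == [b"ab", b"c"]

# The pair is rendered with _token_bytes_to_string and stored as a merge:
assert " ".join(map(KimiLinearModel._token_bytes_to_string, parts)) == "ab c"

# Bytes >= 0x80 are escaped as <0xNN> placeholders:
assert KimiLinearModel._token_bytes_to_string(b"\xe4\xb8\xad") == "<0xE4><0xB8><0xAD>"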

def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
@@ -5229,6 +5285,29 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
logger.info(f"Processing {name}: shape before = {tuple(data_torch.shape)}")

# GGUF dimension handling: the GGUF writer stores dimensions in reverse of the
# numpy/PyTorch shape order, so a (rows, cols) weight written as-is is reported
# as [cols, rows] on the llama.cpp side.
# Example (token embedding):
#   HF tensor is [vocab, n_embd] = [163840, 2304]
#   llama.cpp create_tensor expects {n_embd, n_vocab} = {2304, 163840}
#   write as-is -> GGUF reports [2304, 163840]  ✓
#   transpose   -> GGUF reports [163840, 2304]  ✗
# So no transpose is applied to any weights here.
if len(data_torch.shape) == 2 and "weight" in name:
logger.info(f"Keeping {name} as-is: {tuple(data_torch.shape)} (GGUF will show reversed)")

# Kimi specific bias
if name.endswith("block_sparse_moe.gate.e_score_correction_bias"):
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, bid)
@@ -5262,49 +5341,24 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
tensors.append((new_name, data_torch))
return tensors
return []

return [(self.map_tensor_name(name), data_torch)]

def get_vocab_base(self) -> tuple[list[str], list[int], str]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

# Call parent implementation with the loaded tokenizer
tokens: list[str] = []
toktypes: list[int] = []

reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()}
added_vocab = tokenizer.get_added_vocab()
mapped_name = self.map_tensor_name(name)
logger.info(f"Returning {mapped_name}: shape after = {tuple(data_torch.shape)}")
return [(mapped_name, data_torch)]

for i in range(len(tokenizer)):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)

tokpre = self.get_vocab_base_pre(tokenizer)

return tokens, toktypes, tokpre

def set_vocab(self):
try:
self._set_vocab_gpt2()
except Exception as e:
logger.warning(f"Failed to load tokenizer with GPT2 method: {e}")
logger.warning("Attempting to use sentencepiece tokenizer")
try:
self._set_vocab_sentencepiece()
except Exception as e2:
logger.error(f"Failed to load tokenizer: {e2}")
raise
def get_vocab_base(self) -> tuple[list[str], list[int], str]:
# This method is not used when set_vocab is overridden
# It is kept for completeness in case it is called elsewhere
logger.warning("get_vocab_base called, but set_vocab is already overridden")
vocab_size = self.hparams.get("vocab_size", 100)
tokens = [f"<token_{i}>" for i in range(vocab_size)]
tokens[0] = "<unk>"
tokens[1] = "<s>"
tokens[2] = "</s>"
toktypes = [gguf.TokenType.NORMAL] * vocab_size
return tokens, toktypes, "gpt-2"

# Note: set_vocab() is defined earlier in this class (around line 5144)
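The dimension-handling comments in modify_tensors above reduce to one convention, restated here as a self-contained check (illustrative only; the actual reversal happens inside gguf-py's writer, not in this converter):

# Hedged sketch: GGUF records dimensions with the contiguous axis first, i.e.
# the reverse of the numpy/PyTorch shape tuple, so HF weights are written as-is.
hf_embd_shape = (163840, 2304)            # HF token_embd.weight: [vocab, n_embd]
assert tuple(reversed(hf_embd_shape)) == (2304, 163840)   # {n_embd, n_vocab} as llama.cpp expects

hf_proj_shape = (576, 2304)               # hypothetical (out_features, in_features) projection
assert tuple(reversed(hf_proj_shape)) == (2304, 576)      # ggml order: {in, out}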


@ModelBase.register("InternLM3ForCausalLM")
6 changes: 5 additions & 1 deletion gguf-py/gguf/constants.py
@@ -445,7 +445,7 @@ class MODEL_ARCH(IntEnum):
MINIMAXM2 = auto()
RND1 = auto()
PANGU_EMBED = auto()
KIMI = auto()
KIMI = auto() # Kimi-Linear (hybrid MLA+KDA)
Collaborator suggested change:
KIMI = auto() # Kimi-Linear (hybrid MLA+KDA)
KIMI_LINEAR = auto()



class VISION_PROJECTOR_TYPE(IntEnum):
@@ -3440,6 +3440,10 @@ class VisionProjectorType:
KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV
KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS
KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
KEY_ATTENTION_Q_LORA_RANK = Keys.Attention.Q_LORA_RANK
KEY_ATTENTION_KV_LORA_RANK = Keys.Attention.KV_LORA_RANK
KEY_ATTENTION_KEY_LENGTH_MLA = Keys.Attention.KEY_LENGTH_MLA
KEY_ATTENTION_VALUE_LENGTH_MLA = Keys.Attention.VALUE_LENGTH_MLA
Collaborator, on lines +3441 to +3444, suggested change:
KEY_ATTENTION_Q_LORA_RANK = Keys.Attention.Q_LORA_RANK
KEY_ATTENTION_KV_LORA_RANK = Keys.Attention.KV_LORA_RANK
KEY_ATTENTION_KEY_LENGTH_MLA = Keys.Attention.KEY_LENGTH_MLA
KEY_ATTENTION_VALUE_LENGTH_MLA = Keys.Attention.VALUE_LENGTH_MLA

These are old aliases.


# RoPE
KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -82,6 +82,7 @@ add_library(llama
models/internlm2.cpp
models/jais.cpp
models/jamba.cpp
models/kimi.cpp
Collaborator suggested change:
models/kimi.cpp
models/kimi-linear.cpp

models/lfm2.cpp
models/llada-moe.cpp
models/llada.cpp