feat(kimi): add KDA CUDA kernel and optimize recurrence implementation
- Add KDA (Kimi Delta Attention) CUDA kernel (kda-scan.cu)
- Fix recurrence order: decay first, then retrieval
- Verify CPU/CUDA implementation consistency
- Support head_dim=128, L2 normalization for Q/K
cacaview committed Nov 29, 2025
commit 6b20da1d3c19e74bafe46a08c5e00836dfc6217d
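For reference, a minimal NumPy sketch of the recurrence this commit implements (decay applied before retrieval, with L2-normalized q/k); the function name, argument layout, and epsilon are illustrative assumptions, not part of the patch:

import numpy as np

def kda_scan_ref(h, q, k, v, g, beta, eps=1e-6):
    # h: (D, D) state; q/k/v/g: (T, D); beta: (T,). Returns y: (T, D).
    D = h.shape[0]
    scale = 1.0 / np.sqrt(D)
    y = np.zeros_like(v)
    for t in range(v.shape[0]):
        # L2-normalize q and k; q also carries the 1/sqrt(D) attention scale
        qt = q[t] / np.sqrt(np.sum(q[t] ** 2) + eps) * scale
        kt = k[t] / np.sqrt(np.sum(k[t] ** 2) + eps)
        h = np.exp(g[t])[:, None] * h                 # step 1: decay first
        hk = h.T @ kt                                 # step 2: retrieve from decayed state
        h = h + np.outer(kt, beta[t] * (v[t] - hk))   # step 3: rank-1 delta update
        y[t] = h.T @ qt                               # step 4: output
    return y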
8 changes: 4 additions & 4 deletions convert_hf_to_gguf.py
@@ -563,10 +563,10 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
-                # KDA conv weights should be F32
-                gguf.MODEL_TENSOR.KDA_Q_CONV,
-                gguf.MODEL_TENSOR.KDA_K_CONV,
-                gguf.MODEL_TENSOR.KDA_V_CONV,
+                # Kimi KDA conv weights should be F32
+                gguf.MODEL_TENSOR.SSM_CONV1D_Q,
+                gguf.MODEL_TENSOR.SSM_CONV1D_K,
+                gguf.MODEL_TENSOR.SSM_CONV1D_V,
)
)
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
14 changes: 11 additions & 3 deletions ggml/src/ggml-cpu/ops.cpp
@@ -8953,7 +8953,7 @@ static void ggml_compute_forward_kda_scan_f32(
float * hk_buf = (float *) malloc(head_dim * sizeof(float));

static int debug_count = 0;
-    bool do_debug = false; // (ith == 0 && debug_count++ < 3);
+    bool do_debug = false; // (ith == 0 && debug_count++ < 20);

for (int i3 = 0; i3 < n_seqs; ++i3) {
// Get initial hidden state for this sequence
@@ -9021,6 +9021,9 @@ static void ggml_compute_forward_kda_scan_f32(
k[i] = k_raw[i] / k_norm;
}

+                    // KDA recurrence (decay applied first):
+                    //   h_dec = exp(g[t]) * h[t-1];  h[t] = h_dec + k[t]^T * (beta[t] * (v[t] - h_dec @ k[t]))
+
// Step 1: Apply decay to h first: h = h * exp(g)
for (int i = 0; i < head_dim; ++i) {
const float exp_gi = expf(g[i]);
@@ -9060,8 +9063,13 @@

// Debug output
if (do_debug && ih == 0 && it == 0 && i3 == 0) {
fprintf(stderr, "DEBUG KDA output: y[0]=%f, y[1]=%f, h[0]=%f, h[1]=%f\n",
y[0], y[1], h[0], h[1]);
// Find max abs value in h for stability check
float h_max = 0.0f;
for (int i = 0; i < head_dim * head_dim; i++) {
if (fabsf(h[i]) > h_max) h_max = fabsf(h[i]);
}
fprintf(stderr, "DEBUG KDA: y[0]=%.6f, h_max=%.6f, exp(g[0])=%.6f\n",
y[0], h_max, expf(g[0]));
}
}
}
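The hunk above carries the substance of the ordering fix: the retrieval h @ k must see the already-decayed state. A small NumPy illustration with hypothetical values (the "old order" shown is one plausible reading of the pre-fix code, per the commit message) of how the two orderings diverge:

import numpy as np

rng = np.random.default_rng(0)
D = 4
h0 = rng.normal(size=(D, D))
k = rng.normal(size=D); k /= np.linalg.norm(k)
v = rng.normal(size=D)
g = -0.1 * rng.random(D)          # log-decay factors, typically <= 0
beta = 0.5

# fixed order: decay first, then retrieve from the decayed state
h_a = np.exp(g)[:, None] * h0
h_a += np.outer(k, beta * (v - h_a.T @ k))

# old order: retrieve from the stale state, then decay
h_b = np.exp(g)[:, None] * h0 + np.outer(k, beta * (v - h0.T @ k))

print(np.max(np.abs(h_a - h_b)))  # non-zero: the orderings are not equivalent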
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -41,6 +41,7 @@
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/kda-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
@@ -2691,6 +2692,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
+        case GGML_OP_KDA_SCAN:
+            ggml_cuda_op_kda_scan(ctx, dst);
+            break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
@@ -4202,6 +4206,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
}
}
+        case GGML_OP_KDA_SCAN: {
+            // KDA scan kernel supports head_dim 64 or 128
+            const int64_t head_dim = op->src[0]->ne[0];
+            return head_dim == 64 || head_dim == 128;
+        }
case GGML_OP_SSM_CONV: {
// assumes d_inner % threads == 0
return op->src[0]->ne[1] % 128 == 0;
209 changes: 209 additions & 0 deletions ggml/src/ggml-cuda/kda-scan.cu
@@ -0,0 +1,209 @@
#include "kda-scan.cuh"

// KDA (Kimi Delta Attention) scan CUDA kernel
// Recurrence (decay applied first):
//   h_dec = exp(g[t]) * h[t-1]
//   h[t]  = h_dec + k[t]^T * (beta[t] * (v[t] - h_dec @ k[t]))
//   o[t]  = q[t]^T @ h[t]
//
// This kernel uses global memory for the hidden state to avoid shared memory limits.
// Each block processes one head for one sequence.

__global__ void kda_scan_f32_kernel(
    const float * __restrict__ src0, // h: [head_dim, head_dim, n_head, n_seqs] (state buffer may hold more sequences; rows selected via src6 ids)
const float * __restrict__ src1, // q: [head_dim, n_head, n_seq_tokens, n_seqs]
const float * __restrict__ src2, // k: [head_dim, n_head, n_seq_tokens, n_seqs]
const float * __restrict__ src3, // v: [head_dim, n_head, n_seq_tokens, n_seqs]
const float * __restrict__ src4, // g: [head_dim, n_head, n_seq_tokens, n_seqs]
const float * __restrict__ src5, // beta: [n_head, n_seq_tokens, n_seqs]
const int32_t * __restrict__ src6, // ids: [n_seqs]
float * __restrict__ dst,
const int64_t head_dim,
const int64_t n_head,
const int64_t n_seq_tokens,
const int64_t n_seqs,
const int64_t y_off) // offset to state output in dst (in floats)
{
// Each block handles one head for one sequence
const int seq_idx = blockIdx.x / n_head;
const int head_idx = blockIdx.x % n_head;
const int tid = threadIdx.x;
const int n_threads = blockDim.x;

if (seq_idx >= n_seqs || head_idx >= n_head) return;

// Get sequence ID for initial state
const int src_seq = src6[seq_idx];

// Shared memory for temporary buffers
extern __shared__ float smem[];
float * hk_buf = smem; // [head_dim] - h @ k buffer
float * q_norm = smem + head_dim; // [head_dim] - normalized q
float * k_norm = q_norm + head_dim; // [head_dim] - normalized k
float * warp_sums = k_norm + head_dim; // [64] - for reductions

// Pointers to input/output data for this head
const int64_t h_stride_head = head_dim * head_dim;
const int64_t h_stride_seq = h_stride_head * n_head;
const int64_t qkv_stride_head = head_dim;
const int64_t qkv_stride_token = head_dim * n_head;
const int64_t qkv_stride_seq = qkv_stride_token * n_seq_tokens;
const int64_t beta_stride_token = n_head;
const int64_t beta_stride_seq = beta_stride_token * n_seq_tokens;

const float * h_in = src0 + src_seq * h_stride_seq + head_idx * h_stride_head;
float * h_out = dst + y_off + seq_idx * h_stride_seq + head_idx * h_stride_head;
float * y_out = dst + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;

// Copy initial state to output (we'll update in place)
for (int i = tid; i < head_dim * head_dim; i += n_threads) {
float val = h_in[i];
if (!isfinite(val) || fabsf(val) > 1e6f) {
val = 0.0f;
}
h_out[i] = val;
}
__syncthreads();

const float scale = 1.0f / sqrtf((float)head_dim);

// Process each token sequentially
for (int t = 0; t < n_seq_tokens; ++t) {
const float * q_raw = src1 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
const float * k_raw = src2 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
const float * v = src3 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
const float * g = src4 + t * qkv_stride_token + seq_idx * qkv_stride_seq + head_idx * qkv_stride_head;
const float beta = src5[t * beta_stride_token + seq_idx * beta_stride_seq + head_idx];
float * y = y_out + t * qkv_stride_token;

// Step 1: L2 normalize q and k
float q_sq_sum = 0.0f, k_sq_sum = 0.0f;
for (int i = tid; i < head_dim; i += n_threads) {
q_sq_sum += q_raw[i] * q_raw[i];
k_sq_sum += k_raw[i] * k_raw[i];
}

// Warp reduction
for (int offset = warpSize/2; offset > 0; offset /= 2) {
q_sq_sum += __shfl_down_sync(0xffffffff, q_sq_sum, offset);
k_sq_sum += __shfl_down_sync(0xffffffff, k_sq_sum, offset);
}

// Cross-warp reduction
int warp_id = tid / warpSize;
int lane_id = tid % warpSize;
if (lane_id == 0 && warp_id < 32) {
warp_sums[warp_id] = q_sq_sum;
warp_sums[32 + warp_id] = k_sq_sum;
}
__syncthreads();

if (tid == 0) {
float total_q = 0.0f, total_k = 0.0f;
for (int i = 0; i < (n_threads + warpSize - 1) / warpSize; ++i) {
total_q += warp_sums[i];
total_k += warp_sums[32 + i];
}
warp_sums[0] = rsqrtf(total_q + 1e-6f) * scale;
warp_sums[1] = rsqrtf(total_k + 1e-6f);
}
__syncthreads();

float q_norm_factor = warp_sums[0];
float k_norm_factor = warp_sums[1];

// Store normalized q and k
for (int i = tid; i < head_dim; i += n_threads) {
q_norm[i] = q_raw[i] * q_norm_factor;
k_norm[i] = k_raw[i] * k_norm_factor;
}
__syncthreads();

        // KDA recurrence, decay applied first:
        //   h_dec = exp(g[t]) * h[t-1];  h[t] = h_dec + k[t]^T * (beta[t] * (v[t] - h_dec @ k[t]))

// Step 2: Apply decay to h: h = h * exp(g)
for (int idx = tid; idx < head_dim * head_dim; idx += n_threads) {
int i = idx / head_dim;
float exp_gi = expf(g[i]);
h_out[idx] *= exp_gi;
}
__syncthreads();

// Step 3: Compute h^T @ k -> hk_buf
for (int j = tid; j < head_dim; j += n_threads) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h_out[i * head_dim + j] * k_norm[i];
}
hk_buf[j] = sum;
}
__syncthreads();

// Step 4: Update h: h = h + outer(k, beta * (v - hk))
for (int idx = tid; idx < head_dim * head_dim; idx += n_threads) {
int i = idx / head_dim;
int j = idx % head_dim;
float delta_j = beta * (v[j] - hk_buf[j]);
h_out[idx] += k_norm[i] * delta_j;
}
__syncthreads();

// Step 5: Compute output y = h^T @ q
for (int j = tid; j < head_dim; j += n_threads) {
float sum = 0.0f;
for (int i = 0; i < head_dim; ++i) {
sum += h_out[i * head_dim + j] * q_norm[i];
}
y[j] = sum;
}
__syncthreads();
}
}

void ggml_cuda_op_kda_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // h
const ggml_tensor * src1 = dst->src[1]; // q
const ggml_tensor * src2 = dst->src[2]; // k
const ggml_tensor * src3 = dst->src[3]; // v
const ggml_tensor * src4 = dst->src[4]; // g
const ggml_tensor * src5 = dst->src[5]; // beta
const ggml_tensor * src6 = dst->src[6]; // ids

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src3->type == GGML_TYPE_F32);
GGML_ASSERT(src4->type == GGML_TYPE_F32);
GGML_ASSERT(src5->type == GGML_TYPE_F32);
GGML_ASSERT(src6->type == GGML_TYPE_I32);

const int64_t head_dim = src0->ne[0];
const int64_t n_head = src1->ne[1];
const int64_t n_seq_tokens = src1->ne[2];
const int64_t n_seqs = src1->ne[3];

// Output offset for hidden state (after attention output) - in floats
const int64_t y_off = ggml_nelements(src1);

const float * h_d = (const float *)src0->data;
const float * q_d = (const float *)src1->data;
const float * k_d = (const float *)src2->data;
const float * v_d = (const float *)src3->data;
const float * g_d = (const float *)src4->data;
const float * beta_d = (const float *)src5->data;
const int32_t * ids_d = (const int32_t *)src6->data;
float * dst_d = (float *)dst->data;

cudaStream_t stream = ctx.stream();

// Launch kernel: one block per (sequence, head) pair
const int n_blocks = n_seqs * n_head;
const int n_threads = 128;

// Shared memory: hk_buf[head_dim] + q_norm[head_dim] + k_norm[head_dim] + warp_sums[64]
size_t smem_size = (3 * head_dim + 64) * sizeof(float);

kda_scan_f32_kernel<<<n_blocks, n_threads, smem_size, stream>>>(
h_d, q_d, k_d, v_d, g_d, beta_d, ids_d, dst_d,
head_dim, n_head, n_seq_tokens, n_seqs, y_off);
}
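A quick sanity check, in plain Python, of the resources the launch code above requests per block; the 48 KiB figure is the common default shared-memory limit (assumed here rather than queried), and n_head/n_seqs in the example call are arbitrary:

def kda_launch_config(head_dim: int, n_head: int, n_seqs: int):
    # one block per (sequence, head) pair, 128 threads per block
    assert head_dim in (64, 128), "kernel supports head_dim 64 or 128"
    n_blocks = n_seqs * n_head
    n_threads = 128
    # hk_buf + q_norm + k_norm (head_dim floats each) + 64 reduction slots
    smem_bytes = (3 * head_dim + 64) * 4
    return n_blocks, n_threads, smem_bytes

# head_dim=128: (3*128 + 64) * 4 = 1792 bytes per block, far below 48 KiB
print(kda_launch_config(128, n_head=32, n_seqs=1))   # (32, 128, 1792)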
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda/kda-scan.cuh
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_kda_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
68 changes: 33 additions & 35 deletions gguf-py/gguf/constants.py
@@ -703,18 +703,17 @@ class MODEL_TENSOR(IntEnum):
A_MMPROJ_FC = auto()
A_MM_NORM_PRE = auto()
A_MM_NORM_MID = auto()
-    # Kimi Linear
-    KDA_Q_CONV = auto()
-    KDA_K_CONV = auto()
-    KDA_V_CONV = auto()
-    KDA_F_A = auto()
-    KDA_F_B = auto()
-    KDA_B = auto()
-    KDA_A_LOG = auto()
-    KDA_G_A = auto()
-    KDA_G_B = auto()
-    KDA_O_NORM = auto()
-    KDA_DT_BIAS = auto()
+    # Kimi Linear KDA (using SSM_ prefix for consistency)
+    SSM_CONV1D_Q = auto()
+    SSM_CONV1D_K = auto()
+    SSM_CONV1D_V = auto()
+    SSM_F_A = auto()
+    SSM_F_B = auto()
+    SSM_BETA = auto()
+    SSM_A_LOG = auto()
+    SSM_G_A = auto()
+    SSM_G_B = auto()
+    SSM_DT_B = auto()
Collaborator review comment on lines +706 to +716:
Move them together with the other SSM_* ones.
# nextn/mtp
NEXTN_EH_PROJ = auto()
NEXTN_EMBED_TOKENS = auto()
@@ -1087,18 +1086,17 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
-    # Kimi Linear
-    MODEL_TENSOR.KDA_Q_CONV: "blk.{bid}.kda_q_conv",
-    MODEL_TENSOR.KDA_K_CONV: "blk.{bid}.kda_k_conv",
-    MODEL_TENSOR.KDA_V_CONV: "blk.{bid}.kda_v_conv",
-    MODEL_TENSOR.KDA_F_A: "blk.{bid}.kda_f_a",
-    MODEL_TENSOR.KDA_F_B: "blk.{bid}.kda_f_b",
-    MODEL_TENSOR.KDA_B: "blk.{bid}.kda_b",
-    MODEL_TENSOR.KDA_A_LOG: "blk.{bid}.kda_a_log",
-    MODEL_TENSOR.KDA_G_A: "blk.{bid}.kda_g_a",
-    MODEL_TENSOR.KDA_G_B: "blk.{bid}.kda_g_b",
-    MODEL_TENSOR.KDA_O_NORM: "blk.{bid}.kda_o_norm",
-    MODEL_TENSOR.KDA_DT_BIAS: "blk.{bid}.kda_dt_bias",
+    # Kimi Linear KDA (using SSM_ prefix for consistency)
+    MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q",
+    MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k",
+    MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v",
+    MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a",
+    MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b",
+    MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta",
+    MODEL_TENSOR.SSM_A_LOG: "blk.{bid}.ssm_a",
+    MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a",
+    MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b",
+    MODEL_TENSOR.SSM_DT_B: "blk.{bid}.ssm_dt",
Collaborator review comment on lines +1089 to +1099:
Move them.
# NextN/MTP
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
@@ -3121,17 +3119,17 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
-        MODEL_TENSOR.KDA_Q_CONV,
-        MODEL_TENSOR.KDA_K_CONV,
-        MODEL_TENSOR.KDA_V_CONV,
-        MODEL_TENSOR.KDA_F_A,
-        MODEL_TENSOR.KDA_F_B,
-        MODEL_TENSOR.KDA_B,
-        MODEL_TENSOR.KDA_A_LOG,
-        MODEL_TENSOR.KDA_G_A,
-        MODEL_TENSOR.KDA_G_B,
-        MODEL_TENSOR.KDA_O_NORM,
-        MODEL_TENSOR.KDA_DT_BIAS,
+        MODEL_TENSOR.SSM_CONV1D_Q,
+        MODEL_TENSOR.SSM_CONV1D_K,
+        MODEL_TENSOR.SSM_CONV1D_V,
+        MODEL_TENSOR.SSM_F_A,
+        MODEL_TENSOR.SSM_F_B,
+        MODEL_TENSOR.SSM_BETA,
+        MODEL_TENSOR.SSM_A_LOG,
+        MODEL_TENSOR.SSM_G_A,
+        MODEL_TENSOR.SSM_G_B,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_DT_B,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
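For downstream code that referenced the old enum members, the rename mapping as read off this diff (a hypothetical migration aid, not part of the patch; note that KDA_O_NORM folds into the pre-existing SSM_NORM rather than a new member):

KDA_TO_SSM = {
    "KDA_Q_CONV":  "SSM_CONV1D_Q",
    "KDA_K_CONV":  "SSM_CONV1D_K",
    "KDA_V_CONV":  "SSM_CONV1D_V",
    "KDA_F_A":     "SSM_F_A",
    "KDA_F_B":     "SSM_F_B",
    "KDA_B":       "SSM_BETA",
    "KDA_A_LOG":   "SSM_A_LOG",
    "KDA_G_A":     "SSM_G_A",
    "KDA_G_B":     "SSM_G_B",
    "KDA_O_NORM":  "SSM_NORM",   # SSM_NORM already existed in the enum
    "KDA_DT_BIAS": "SSM_DT_B",
}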