Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
clip impl
  • Loading branch information
ngxson committed Nov 28, 2025
commit 9149ff70f11c165c16e5b961211bd2d9ef623360
1 change: 0 additions & 1 deletion convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10095,7 +10095,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
del bid # unused

if name.startswith("vision_tower."):
print(name)
return [(self.map_tensor_name(name), data_torch)]

return [] # skip other tensors
Expand Down
4 changes: 2 additions & 2 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,7 +1351,7 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
"vision_tower.blocks.{bid}.mlp.fc2", # dots.ocr
"vision_tower.blocks.{bid}.mlp.fc3", # dots.ocr
),

MODEL_TENSOR.V_ENC_FFN_GATE: (
Expand All @@ -1374,7 +1374,7 @@ class TensorNameMap:
"visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
"vision_tower.blocks.{bid}.mlp.fc3", # dots.ocr
"vision_tower.blocks.{bid}.mlp.fc2", # dots.ocr
),

MODEL_TENSOR.V_LAYER_SCALE_1: (
Expand Down
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ enum projector_type {
PROJECTOR_TYPE_LIGHTONOCR,
PROJECTOR_TYPE_COGVLM,
PROJECTOR_TYPE_JANUS_PRO,
PROJECTOR_TYPE_DOTS_OCR,
PROJECTOR_TYPE_UNKNOWN,
};

Expand All @@ -184,6 +185,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
{ PROJECTOR_TYPE_DOTS_OCR, "dots_ocr"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {
Expand Down
176 changes: 144 additions & 32 deletions tools/mtmd/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ struct clip_model {
// pixtral
ggml_tensor * token_embd_img_break = nullptr;
ggml_tensor * mm_patch_merger_w = nullptr;
ggml_tensor * mm_patch_merger_b = nullptr;

// ultravox / whisper encoder
ggml_tensor * conv1d_1_w = nullptr;
Expand Down Expand Up @@ -1839,15 +1840,7 @@ struct clip_graph {
if (model.audio_has_stack_frames()) {
// StackAudioFrames
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
int64_t stride = n_embd * hparams.proj_stack_factor;
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
int64_t pad = padded_len - ggml_nelements(cur);
if (pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
ggml_row_size(cur->type, stride), 0);
cur = build_stacked_embeddings(cur, hparams.proj_stack_factor);
cb(cur, "after_stacked", -1);
}

Expand Down Expand Up @@ -1991,6 +1984,48 @@ struct clip_graph {
return gf;
}

// Build the compute graph for the dots.ocr vision tower:
//   1. ViT encoder with RMS norm, 2D RoPE applied per layer, no learned position embedding
//   2. patch merger: input norm, then n_merge*n_merge consecutive rows stacked into one
//   3. projector: mm_0 -> GELU -> mm_2 (no gate branch)
// Returns the graph `gf` expanded up to the projector output.
ggml_cgraph * build_dots_ocr() {
    // creates one I32 graph input with n_patches entries, registered under `name`
    auto new_pos_input = [&](const char * name) {
        ggml_tensor * t = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
        ggml_set_name(t, name);
        ggml_set_input(t);
        return t;
    };

    // per-patch 2D coordinates, filled in by the caller at encode time
    ggml_tensor * pos_rows = new_pos_input("pos_h");
    ggml_tensor * pos_cols = new_pos_input("pos_w");

    // position callback handed to build_vit; the layer argument is unused
    auto rope_2d = [&](ggml_tensor * x, const clip_layer &) {
        return build_rope_2d(ctx0, x, pos_rows, pos_cols, hparams.rope_theta, false);
    };

    ggml_tensor * patches_in = build_inp();
    ggml_tensor * embd = build_vit(
                            patches_in, n_patches,
                            NORM_TYPE_RMS,
                            hparams.ffn_op,
                            nullptr, // no learned pos embd
                            rope_2d);

    // patch merger: normalize, then fold n_merge*n_merge consecutive rows into one
    GGML_ASSERT(hparams.n_merge > 0);
    embd = build_norm(embd, model.mm_input_norm_w, model.mm_input_norm_b, NORM_TYPE_NORMAL, 1e-6, -1);
    embd = build_stacked_embeddings(embd, hparams.n_merge * hparams.n_merge);
    cb(embd, "after_patch_merger", -1);

    // projector: two linear layers with GELU in between, gate slots left empty
    embd = build_ffn(embd,
                model.mm_0_w, model.mm_0_b,
                nullptr, nullptr, // no gate
                model.mm_2_w, model.mm_2_b,
                FFN_GELU, -1);
    cb(embd, "after_projector", -1);

    // finalize the graph with the projector output as the result node
    ggml_build_forward_expand(gf, embd);

    return gf;
}
}

private:
//
// utility functions
Expand Down Expand Up @@ -2065,34 +2100,69 @@ struct clip_graph {

// self-attention
{
ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}
ggml_tensor * Qcur = nullptr;
ggml_tensor * Kcur = nullptr;
ggml_tensor * Vcur = nullptr;
if (layer.qkv_w != nullptr) {
// fused qkv
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
if (layer.qkv_b != nullptr) {
cur = ggml_add(ctx0, cur, layer.qkv_b);
}

ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}
Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
d_head * ggml_element_size(cur), cur->nb[1],
/* offset */ 0);

ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}
Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
d_head * ggml_element_size(cur), cur->nb[1],
/* offset */ n_embd * ggml_element_size(cur));

if (layer.q_norm) {
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
cb(Qcur, "Qcur_norm", il);
}
Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
d_head * ggml_element_size(cur), cur->nb[1],
/* offset */ 2 * n_embd * ggml_element_size(cur));

if (layer.k_norm) {
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
cb(Kcur, "Kcur_norm", il);
}
if (layer.q_norm) {
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
cb(Qcur, "Qcur_norm", il);
}

Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
if (layer.k_norm) {
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
cb(Kcur, "Kcur_norm", il);
}

} else {
// separate q, k, v
Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
if (layer.q_b) {
Qcur = ggml_add(ctx0, Qcur, layer.q_b);
}

Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
if (layer.k_b) {
Kcur = ggml_add(ctx0, Kcur, layer.k_b);
}

Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
if (layer.v_b) {
Vcur = ggml_add(ctx0, Vcur, layer.v_b);
}

if (layer.q_norm) {
Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
cb(Qcur, "Qcur_norm", il);
}

if (layer.k_norm) {
Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
cb(Kcur, "Kcur_norm", il);
}

Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
}

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
Expand Down Expand Up @@ -2362,6 +2432,20 @@ struct clip_graph {
return cur;
}

// Stack n_stack consecutive rows of `cur` into one wider row.
//
// The input is handled as a flat run of ggml_nelements(cur) values. When that
// count is not a multiple of the stacked row width (cur->ne[0] * n_stack), the
// flattened tensor is extended with ggml_pad up to the next multiple.
//
//   cur     : tensor whose ne[0] is the width of a single row
//   n_stack : number of consecutive rows to merge into one; must be > 0
//   returns : 2D view with ne[0] * n_stack columns and padded_len / stride rows
//
// NOTE(review): the views below read the data as one dense run - presumably
// `cur` is always contiguous at the call sites; confirm before reusing elsewhere.
ggml_tensor * build_stacked_embeddings(ggml_tensor * cur, int n_stack) {
    // n_stack == 0 would make stride 0 and turn `padded_len / stride` into a
    // division by zero; fail with a readable assertion instead
    GGML_ASSERT(n_stack > 0);

    const int64_t n_elem     = ggml_nelements(cur);
    const int64_t stride     = cur->ne[0] * n_stack;
    const int64_t padded_len = CLIP_ALIGN(n_elem, stride);
    const int64_t pad        = padded_len - n_elem;
    if (pad > 0) {
        // flatten first so the extra elements are appended after the last value
        cur = ggml_view_1d(ctx0, cur, n_elem, 0);
        cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
    }
    // reinterpret the (possibly padded) run as rows of `stride` elements each
    cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
                        ggml_row_size(cur->type, stride), 0);
    return cur;
}

// implementation of the 2D RoPE without adding a new op in ggml
// this is not efficient (use double the memory), but works on all backends
// TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
Expand Down Expand Up @@ -2524,6 +2608,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = graph.build_cogvlm();
} break;
case PROJECTOR_TYPE_DOTS_OCR:
{
res = graph.build_dots_ocr();
} break;
default:
{
res = graph.build_llava();
Expand Down Expand Up @@ -2838,6 +2926,12 @@ struct clip_model_loader {
LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
}
} break;
case PROJECTOR_TYPE_DOTS_OCR:
{
hparams.rope_theta = 10000.0f;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup
} break;
case PROJECTOR_TYPE_LLAMA4:
{
hparams.rope_theta = 10000.0f;
Expand Down Expand Up @@ -3244,6 +3338,15 @@ struct clip_model_loader {
model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
} break;
case PROJECTOR_TYPE_DOTS_OCR:
{
model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
} break;
default:
GGML_ASSERT(false && "unknown projector type");
}
Expand Down Expand Up @@ -4318,6 +4421,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str

case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_LIGHTONOCR:
case PROJECTOR_TYPE_DOTS_OCR:
{
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
clip_image_u8 resized_image;
Expand Down Expand Up @@ -4594,6 +4698,12 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{
n_patches += 2; // for BOI and EOI token embeddings
} break;
case PROJECTOR_TYPE_DOTS_OCR:
{
// dynamic size
int n_stack = params.n_merge * params.n_merge;
n_patches = CLIP_ALIGN(n_patches, n_stack) / n_stack;
} break;
default:
GGML_ABORT("unsupported projector type");
}
Expand Down Expand Up @@ -4870,6 +4980,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_LIGHTONOCR:
case PROJECTOR_TYPE_DOTS_OCR:
{
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
Expand Down Expand Up @@ -5003,6 +5114,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_LIGHTONOCR:
case PROJECTOR_TYPE_DOTS_OCR:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->model.mm_3_b->ne[0];
Expand Down
5 changes: 5 additions & 0 deletions tools/mtmd/mtmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ struct mtmd_context {
img_beg = "<|im_start|>";
img_end = "<|im_end|>";

} else if (proj == PROJECTOR_TYPE_DOTS_OCR) {
// <|img|> ... (image embeddings) ... <|endofimg|>
img_beg = "<|img|>";
img_end = "<|endofimg|>";

}
}

Expand Down
Loading