Skip to content

Commit ff29a86

Browse files
committed
Add safetensors support
This lets us load safetensors files natively, just like GGUF. Signed-off-by: Eric Curtin <[email protected]>
1 parent 03914c7 commit ff29a86

11 files changed

+1663
-0
lines changed

src/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ add_library(llama
3232
llama-quant.cpp
3333
llama-sampling.cpp
3434
llama-vocab.cpp
35+
llama-safetensors.cpp
36+
llama-hf-config.cpp
37+
llama-safetensors-loader.cpp
38+
llama-safetensors-types.cpp
3539
unicode-data.cpp
3640
unicode.cpp
3741
unicode.h

src/llama-hf-config.cpp

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
#include "llama-hf-config.h"
2+
3+
#include <fstream>
4+
#include "../vendor/nlohmann/json.hpp"
5+
6+
using json = nlohmann::json;
7+
8+
bool hf_config::load_from_file(const std::string & config_path) {
9+
std::ifstream f(config_path);
10+
if (!f.is_open()) {
11+
error_msg = "Failed to open config file: " + config_path;
12+
return false;
13+
}
14+
15+
try {
16+
config = std::make_unique<json>();
17+
f >> *config;
18+
} catch (const std::exception & e) {
19+
error_msg = std::string("Failed to parse config JSON: ") + e.what();
20+
return false;
21+
}
22+
23+
return true;
24+
}
25+
26+
bool hf_config::load_from_string(const std::string & json_str) {
27+
try {
28+
config = std::make_unique<json>(json::parse(json_str));
29+
} catch (const std::exception & e) {
30+
error_msg = std::string("Failed to parse config JSON: ") + e.what();
31+
return false;
32+
}
33+
34+
return true;
35+
}
36+
37+
std::string hf_config::get_architecture() const {
38+
if (!config) {
39+
return "";
40+
}
41+
42+
// Check for architectures array (most common)
43+
if (config->contains("architectures") && (*config)["architectures"].is_array()) {
44+
const auto & archs = (*config)["architectures"];
45+
if (!archs.empty() && archs[0].is_string()) {
46+
return archs[0].get<std::string>();
47+
}
48+
}
49+
50+
// Check text_config (for multimodal models)
51+
if (config->contains("text_config") && (*config)["text_config"].is_object()) {
52+
const auto & text_config = (*config)["text_config"];
53+
if (text_config.contains("architectures") && text_config["architectures"].is_array()) {
54+
const auto & archs = text_config["architectures"];
55+
if (!archs.empty() && archs[0].is_string()) {
56+
return archs[0].get<std::string>();
57+
}
58+
}
59+
}
60+
61+
// Check for ssm_cfg (Mamba models)
62+
if (config->contains("ssm_cfg") && (*config)["ssm_cfg"].is_object()) {
63+
const auto & ssm_cfg = (*config)["ssm_cfg"];
64+
if (ssm_cfg.contains("layer") && ssm_cfg["layer"].is_string()) {
65+
return ssm_cfg["layer"].get<std::string>() + "ForCausalLM";
66+
}
67+
}
68+
69+
return "";
70+
}
71+
72+
template<typename T>
73+
bool hf_config::get_value_with_fallback(const std::string & key, T & out) const {
74+
if (!config) {
75+
return false;
76+
}
77+
78+
// First try root level
79+
if (config->contains(key)) {
80+
try {
81+
out = (*config)[key].get<T>();
82+
return true;
83+
} catch (const std::exception &) {
84+
return false;
85+
}
86+
}
87+
88+
// Try text_config (for multimodal models)
89+
if (config->contains("text_config") && (*config)["text_config"].is_object()) {
90+
const auto & text_config = (*config)["text_config"];
91+
if (text_config.contains(key)) {
92+
try {
93+
out = text_config[key].get<T>();
94+
return true;
95+
} catch (const std::exception &) {
96+
return false;
97+
}
98+
}
99+
}
100+
101+
return false;
102+
}
103+
104+
bool hf_config::get_int(const std::string & key, int64_t & out) const {
105+
return get_value_with_fallback(key, out);
106+
}
107+
108+
bool hf_config::get_float(const std::string & key, double & out) const {
109+
return get_value_with_fallback(key, out);
110+
}
111+
112+
bool hf_config::get_string(const std::string & key, std::string & out) const {
113+
return get_value_with_fallback(key, out);
114+
}
115+
116+
bool hf_config::get_bool(const std::string & key, bool & out) const {
117+
return get_value_with_fallback(key, out);
118+
}
119+
120+
bool hf_config::has_key(const std::string & key) const {
121+
if (!config) {
122+
return false;
123+
}
124+
125+
if (config->contains(key)) {
126+
return true;
127+
}
128+
129+
// Check text_config
130+
if (config->contains("text_config") && (*config)["text_config"].is_object()) {
131+
return (*config)["text_config"].contains(key);
132+
}
133+
134+
return false;
135+
}
136+
137+
// Expose the raw parsed JSON document for callers that need fields not
// covered by the typed getters. Returns nullptr when no config is loaded.
// The pointer is owned by this object and is invalidated by the next
// successful load_from_file()/load_from_string() call.
const nlohmann::json * hf_config::get_json() const {
    return config.get();
}
140+
141+
// Common configuration getters
142+
143+
int64_t hf_config::get_hidden_size() const {
144+
int64_t val = 0;
145+
// Try multiple possible keys
146+
if (get_int("hidden_size", val)) return val;
147+
if (get_int("d_model", val)) return val;
148+
if (get_int("n_embd", val)) return val;
149+
return 0;
150+
}
151+
152+
int64_t hf_config::get_num_hidden_layers() const {
153+
int64_t val = 0;
154+
if (get_int("num_hidden_layers", val)) return val;
155+
if (get_int("n_layers", val)) return val;
156+
if (get_int("n_layer", val)) return val;
157+
if (get_int("num_layers", val)) return val;
158+
return 0;
159+
}
160+
161+
int64_t hf_config::get_num_attention_heads() const {
162+
int64_t val = 0;
163+
if (get_int("num_attention_heads", val)) return val;
164+
if (get_int("n_heads", val)) return val;
165+
if (get_int("n_head", val)) return val;
166+
return 0;
167+
}
168+
169+
int64_t hf_config::get_num_key_value_heads() const {
170+
int64_t val = 0;
171+
if (get_int("num_key_value_heads", val)) return val;
172+
// If not specified, defaults to num_attention_heads (MHA)
173+
return get_num_attention_heads();
174+
}
175+
176+
int64_t hf_config::get_intermediate_size() const {
177+
int64_t val = 0;
178+
if (get_int("intermediate_size", val)) return val;
179+
if (get_int("n_inner", val)) return val;
180+
return 0;
181+
}
182+
183+
int64_t hf_config::get_vocab_size() const {
184+
int64_t val = 0;
185+
if (get_int("vocab_size", val)) return val;
186+
if (get_int("padded_vocab_size", val)) return val;
187+
return 0;
188+
}
189+
190+
int64_t hf_config::get_max_position_embeddings() const {
191+
int64_t val = 0;
192+
if (get_int("max_position_embeddings", val)) return val;
193+
if (get_int("n_positions", val)) return val;
194+
if (get_int("n_ctx", val)) return val;
195+
return 0;
196+
}
197+
198+
double hf_config::get_rms_norm_eps() const {
199+
double val = 0;
200+
if (get_float("rms_norm_eps", val)) return val;
201+
if (get_float("layer_norm_eps", val)) return val;
202+
if (get_float("layer_norm_epsilon", val)) return val;
203+
return 1e-5; // common default
204+
}
205+
206+
std::string hf_config::get_rope_scaling_type() const {
207+
if (!config) {
208+
return "";
209+
}
210+
211+
// Check for rope_scaling object
212+
if (config->contains("rope_scaling") && (*config)["rope_scaling"].is_object()) {
213+
const auto & rope_scaling = (*config)["rope_scaling"];
214+
if (rope_scaling.contains("type") && rope_scaling["type"].is_string()) {
215+
return rope_scaling["type"].get<std::string>();
216+
}
217+
}
218+
219+
return "";
220+
}

src/llama-hf-config.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#pragma once
2+
3+
#include <map>
4+
#include <memory>
5+
#include <string>
6+
#include <vector>
7+
8+
#include "../vendor/nlohmann/json.hpp"
9+
10+
// HuggingFace model configuration
11+
// Parsed HuggingFace `config.json` for a model checkpoint.
//
// Wraps an nlohmann::json document and provides typed accessors that
// follow common HF conventions: key aliases per architecture family and
// values nested under "text_config" for multimodal models.
class hf_config {
public:
    hf_config() = default;
    ~hf_config() = default;

    // Load config from file; returns false and records a message
    // (see get_error()) on open or parse failure.
    bool load_from_file(const std::string & config_path);

    // Load config from JSON string; returns false and records a message
    // on parse failure.
    bool load_from_string(const std::string & json_str);

    // Get architecture name (e.g., "LlamaForCausalLM", "MistralForCausalLM").
    // Returns "" when no config is loaded or no architecture is found.
    std::string get_architecture() const;

    // Get a configuration value as integer; true only when the key exists
    // (at the root or under "text_config") and converts cleanly.
    bool get_int(const std::string & key, int64_t & out) const;

    // Get a configuration value as float (same lookup rules as get_int)
    bool get_float(const std::string & key, double & out) const;

    // Get a configuration value as string (same lookup rules as get_int)
    bool get_string(const std::string & key, std::string & out) const;

    // Get a configuration value as bool (same lookup rules as get_int)
    bool get_bool(const std::string & key, bool & out) const;

    // Check if a key exists at the root or under "text_config"
    bool has_key(const std::string & key) const;

    // Get raw JSON object (for advanced users); nullptr when nothing is
    // loaded. Owned by this object — invalidated by the next load.
    const nlohmann::json * get_json() const;

    // Get last error message recorded by a failed load
    const std::string & get_error() const { return error_msg; }

    // Common configuration getters. Each tries the usual HF key aliases
    // for its field and returns 0 (or a noted default) when absent.
    int64_t get_hidden_size() const;
    int64_t get_num_hidden_layers() const;
    int64_t get_num_attention_heads() const;
    int64_t get_num_key_value_heads() const;
    int64_t get_intermediate_size() const;
    int64_t get_vocab_size() const;
    int64_t get_max_position_embeddings() const;
    double get_rms_norm_eps() const;
    std::string get_rope_scaling_type() const;

private:
    // Owned parsed document; null until a load_* call succeeds.
    std::unique_ptr<nlohmann::json> config;
    // Human-readable description of the most recent load failure.
    std::string error_msg;

    // Helper to get value, checking nested configs (text_config, vision_config)
    template<typename T>
    bool get_value_with_fallback(const std::string & key, T & out) const;
};

0 commit comments

Comments
 (0)