|
| 1 | +#include "llama-hf-config.h" |
| 2 | + |
| 3 | +#include <fstream> |
| 4 | +#include "../vendor/nlohmann/json.hpp" |
| 5 | + |
| 6 | +using json = nlohmann::json; |
| 7 | + |
| 8 | +bool hf_config::load_from_file(const std::string & config_path) { |
| 9 | + std::ifstream f(config_path); |
| 10 | + if (!f.is_open()) { |
| 11 | + error_msg = "Failed to open config file: " + config_path; |
| 12 | + return false; |
| 13 | + } |
| 14 | + |
| 15 | + try { |
| 16 | + config = std::make_unique<json>(); |
| 17 | + f >> *config; |
| 18 | + } catch (const std::exception & e) { |
| 19 | + error_msg = std::string("Failed to parse config JSON: ") + e.what(); |
| 20 | + return false; |
| 21 | + } |
| 22 | + |
| 23 | + return true; |
| 24 | +} |
| 25 | + |
| 26 | +bool hf_config::load_from_string(const std::string & json_str) { |
| 27 | + try { |
| 28 | + config = std::make_unique<json>(json::parse(json_str)); |
| 29 | + } catch (const std::exception & e) { |
| 30 | + error_msg = std::string("Failed to parse config JSON: ") + e.what(); |
| 31 | + return false; |
| 32 | + } |
| 33 | + |
| 34 | + return true; |
| 35 | +} |
| 36 | + |
| 37 | +std::string hf_config::get_architecture() const { |
| 38 | + if (!config) { |
| 39 | + return ""; |
| 40 | + } |
| 41 | + |
| 42 | + // Check for architectures array (most common) |
| 43 | + if (config->contains("architectures") && (*config)["architectures"].is_array()) { |
| 44 | + const auto & archs = (*config)["architectures"]; |
| 45 | + if (!archs.empty() && archs[0].is_string()) { |
| 46 | + return archs[0].get<std::string>(); |
| 47 | + } |
| 48 | + } |
| 49 | + |
| 50 | + // Check text_config (for multimodal models) |
| 51 | + if (config->contains("text_config") && (*config)["text_config"].is_object()) { |
| 52 | + const auto & text_config = (*config)["text_config"]; |
| 53 | + if (text_config.contains("architectures") && text_config["architectures"].is_array()) { |
| 54 | + const auto & archs = text_config["architectures"]; |
| 55 | + if (!archs.empty() && archs[0].is_string()) { |
| 56 | + return archs[0].get<std::string>(); |
| 57 | + } |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + // Check for ssm_cfg (Mamba models) |
| 62 | + if (config->contains("ssm_cfg") && (*config)["ssm_cfg"].is_object()) { |
| 63 | + const auto & ssm_cfg = (*config)["ssm_cfg"]; |
| 64 | + if (ssm_cfg.contains("layer") && ssm_cfg["layer"].is_string()) { |
| 65 | + return ssm_cfg["layer"].get<std::string>() + "ForCausalLM"; |
| 66 | + } |
| 67 | + } |
| 68 | + |
| 69 | + return ""; |
| 70 | +} |
| 71 | + |
| 72 | +template<typename T> |
| 73 | +bool hf_config::get_value_with_fallback(const std::string & key, T & out) const { |
| 74 | + if (!config) { |
| 75 | + return false; |
| 76 | + } |
| 77 | + |
| 78 | + // First try root level |
| 79 | + if (config->contains(key)) { |
| 80 | + try { |
| 81 | + out = (*config)[key].get<T>(); |
| 82 | + return true; |
| 83 | + } catch (const std::exception &) { |
| 84 | + return false; |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + // Try text_config (for multimodal models) |
| 89 | + if (config->contains("text_config") && (*config)["text_config"].is_object()) { |
| 90 | + const auto & text_config = (*config)["text_config"]; |
| 91 | + if (text_config.contains(key)) { |
| 92 | + try { |
| 93 | + out = text_config[key].get<T>(); |
| 94 | + return true; |
| 95 | + } catch (const std::exception &) { |
| 96 | + return false; |
| 97 | + } |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + return false; |
| 102 | +} |
| 103 | + |
| 104 | +bool hf_config::get_int(const std::string & key, int64_t & out) const { |
| 105 | + return get_value_with_fallback(key, out); |
| 106 | +} |
| 107 | + |
| 108 | +bool hf_config::get_float(const std::string & key, double & out) const { |
| 109 | + return get_value_with_fallback(key, out); |
| 110 | +} |
| 111 | + |
| 112 | +bool hf_config::get_string(const std::string & key, std::string & out) const { |
| 113 | + return get_value_with_fallback(key, out); |
| 114 | +} |
| 115 | + |
| 116 | +bool hf_config::get_bool(const std::string & key, bool & out) const { |
| 117 | + return get_value_with_fallback(key, out); |
| 118 | +} |
| 119 | + |
| 120 | +bool hf_config::has_key(const std::string & key) const { |
| 121 | + if (!config) { |
| 122 | + return false; |
| 123 | + } |
| 124 | + |
| 125 | + if (config->contains(key)) { |
| 126 | + return true; |
| 127 | + } |
| 128 | + |
| 129 | + // Check text_config |
| 130 | + if (config->contains("text_config") && (*config)["text_config"].is_object()) { |
| 131 | + return (*config)["text_config"].contains(key); |
| 132 | + } |
| 133 | + |
| 134 | + return false; |
| 135 | +} |
| 136 | + |
// Expose the parsed JSON document for direct inspection.
// Returns nullptr until load_from_file/load_from_string succeeds.
const nlohmann::json * hf_config::get_json() const {
    return config.get();
}
| 140 | + |
| 141 | +// Common configuration getters |
| 142 | + |
| 143 | +int64_t hf_config::get_hidden_size() const { |
| 144 | + int64_t val = 0; |
| 145 | + // Try multiple possible keys |
| 146 | + if (get_int("hidden_size", val)) return val; |
| 147 | + if (get_int("d_model", val)) return val; |
| 148 | + if (get_int("n_embd", val)) return val; |
| 149 | + return 0; |
| 150 | +} |
| 151 | + |
| 152 | +int64_t hf_config::get_num_hidden_layers() const { |
| 153 | + int64_t val = 0; |
| 154 | + if (get_int("num_hidden_layers", val)) return val; |
| 155 | + if (get_int("n_layers", val)) return val; |
| 156 | + if (get_int("n_layer", val)) return val; |
| 157 | + if (get_int("num_layers", val)) return val; |
| 158 | + return 0; |
| 159 | +} |
| 160 | + |
| 161 | +int64_t hf_config::get_num_attention_heads() const { |
| 162 | + int64_t val = 0; |
| 163 | + if (get_int("num_attention_heads", val)) return val; |
| 164 | + if (get_int("n_heads", val)) return val; |
| 165 | + if (get_int("n_head", val)) return val; |
| 166 | + return 0; |
| 167 | +} |
| 168 | + |
| 169 | +int64_t hf_config::get_num_key_value_heads() const { |
| 170 | + int64_t val = 0; |
| 171 | + if (get_int("num_key_value_heads", val)) return val; |
| 172 | + // If not specified, defaults to num_attention_heads (MHA) |
| 173 | + return get_num_attention_heads(); |
| 174 | +} |
| 175 | + |
| 176 | +int64_t hf_config::get_intermediate_size() const { |
| 177 | + int64_t val = 0; |
| 178 | + if (get_int("intermediate_size", val)) return val; |
| 179 | + if (get_int("n_inner", val)) return val; |
| 180 | + return 0; |
| 181 | +} |
| 182 | + |
| 183 | +int64_t hf_config::get_vocab_size() const { |
| 184 | + int64_t val = 0; |
| 185 | + if (get_int("vocab_size", val)) return val; |
| 186 | + if (get_int("padded_vocab_size", val)) return val; |
| 187 | + return 0; |
| 188 | +} |
| 189 | + |
| 190 | +int64_t hf_config::get_max_position_embeddings() const { |
| 191 | + int64_t val = 0; |
| 192 | + if (get_int("max_position_embeddings", val)) return val; |
| 193 | + if (get_int("n_positions", val)) return val; |
| 194 | + if (get_int("n_ctx", val)) return val; |
| 195 | + return 0; |
| 196 | +} |
| 197 | + |
| 198 | +double hf_config::get_rms_norm_eps() const { |
| 199 | + double val = 0; |
| 200 | + if (get_float("rms_norm_eps", val)) return val; |
| 201 | + if (get_float("layer_norm_eps", val)) return val; |
| 202 | + if (get_float("layer_norm_epsilon", val)) return val; |
| 203 | + return 1e-5; // common default |
| 204 | +} |
| 205 | + |
| 206 | +std::string hf_config::get_rope_scaling_type() const { |
| 207 | + if (!config) { |
| 208 | + return ""; |
| 209 | + } |
| 210 | + |
| 211 | + // Check for rope_scaling object |
| 212 | + if (config->contains("rope_scaling") && (*config)["rope_scaling"].is_object()) { |
| 213 | + const auto & rope_scaling = (*config)["rope_scaling"]; |
| 214 | + if (rope_scaling.contains("type") && rope_scaling["type"].is_string()) { |
| 215 | + return rope_scaling["type"].get<std::string>(); |
| 216 | + } |
| 217 | + } |
| 218 | + |
| 219 | + return ""; |
| 220 | +} |
0 commit comments