Commit d0d1105

enable multi-transformer definition (#966)

* add llama 3.1 8b support
* make Model and ModelArgs the entry points for model definition
* make the model definition support multiple transformers
* make input args static in Model to support export
* fix bugs for gguf and et in the new model definition architecture
* retrieve the text transformer args from ModelArgs
* add a set_cache function to Model to work around the PTEModel issue
* make torchchat rely on torchtune
* remove export_util
* extra torchtune dependency

1 parent 0922e65 commit d0d1105

File tree

9 files changed: +133 -84 lines changed

build/builder.py

Lines changed: 23 additions & 17 deletions
```diff
@@ -12,19 +12,23 @@
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
-import torch.nn as nn
-from torch.distributed.device_mesh import DeviceMesh
 import torch._dynamo.config
 import torch._inductor.config
+import torch.nn as nn
 
 from config.model_config import resolve_model_config
-from distributed import init_distributed, ParallelDims, parallelize_llama
+from distributed import (
+    init_distributed,
+    launch_distributed,
+    ParallelDims,
+    parallelize_llama,
+)
 from quantization.quantize import quantize_model
+from torch.distributed.device_mesh import DeviceMesh
 from utils.measure_time import measure_time
 
-from build.model import Transformer
+from build.model import Model
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
-from distributed import launch_distributed
 
 
 @dataclass
@@ -210,7 +214,7 @@ def __post_init__(self):
 
     def validate_model(
         self,
-        model: Transformer,
+        model: Model,
         model_description: str = "model",
     ) -> None:
         if model is None:
@@ -221,7 +225,7 @@ def validate_model(
 
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
-        use_tiktoken = model.config.use_tiktoken
+        use_tiktoken = model.config.text_transformer_args.use_tiktoken
 
         if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
             raise RuntimeError(
@@ -298,11 +302,11 @@ def _unset_gguf_kwargs(builder_args):
 def _init_model_on_meta_device(builder_args):
     with torch.device("meta"):
         if builder_args.params_path:
-            return Transformer.from_params(builder_args.params_path)
+            return Model.from_params(builder_args.params_path)
         elif builder_args.params_table:
-            return Transformer.from_table(builder_args.params_table)
+            return Model.from_table(builder_args.params_table)
         else:
-            return Transformer.from_name(builder_args.checkpoint_path.parent.name)
+            return Model.from_name(builder_args.checkpoint_path.parent.name)
 
 
 def _load_model_gguf(builder_args, only_config=False):
@@ -311,7 +315,7 @@ def _load_model_gguf(builder_args, only_config=False):
         kwargs = {}
     else:
         kwargs = builder_args.gguf_kwargs
-    model = Transformer.from_gguf(builder_args.gguf_path, **kwargs)
+    model = Model.from_gguf(builder_args.gguf_path, **kwargs)
     return model
 
 
@@ -334,7 +338,6 @@ def _load_model_default(builder_args, only_config=False):
                     mmap=True,
                 )
             )
-
         checkpoint = {}
         for key in cps[0].keys():
             if not torch.allclose(cps[0][key], cps[1][key]):
@@ -355,9 +358,10 @@ def _load_model_default(builder_args, only_config=False):
 
     if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
         checkpoint = checkpoint["model"]
+
+    checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()}
 
-    model.load_state_dict(checkpoint, assign=True, strict=False)
-
+    model.load_state_dict(checkpoint, assign=True, strict=True)
     return model
 
 
```
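This re-keying is the load-time counterpart of the new `Model` wrapper: checkpoints written against a bare `Transformer` have flat keys, while `Model.state_dict()` nests everything under the `text_transformer` submodule, which is also what allows the switch to `strict=True`. A minimal sketch of the idea (the checkpoint path is illustrative):

```python
import torch

from build.model import Model

# Illustrative path; any flat Transformer checkpoint behaves the same way.
checkpoint = torch.load("checkpoints/stories15M/model.pth", mmap=True, weights_only=True)

# A flat key such as "layers.0.attention.wq.weight" must become
# "text_transformer.layers.0.attention.wq.weight" to match Model's module tree.
checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()}

model = Model.from_name("stories15M")
# strict=True is safe once every key carries the submodule prefix.
model.load_state_dict(checkpoint, assign=True, strict=True)
```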

```diff
@@ -380,11 +384,13 @@ def _maybe_init_distributed(
     """
     if not builder_args.use_distributed:
         return None, None
-    dist_config = 'llama3_8B.toml' # TODO - integrate with chat cmd line
+    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line
 
     world_mesh, parallel_dims = launch_distributed(dist_config)
 
-    assert world_mesh is not None and parallel_dims is not None, f"failed to launch distributed using {dist_config}"
+    assert (
+        world_mesh is not None and parallel_dims is not None
+    ), f"failed to launch distributed using {dist_config}"
 
     return world_mesh, parallel_dims
 
@@ -523,7 +529,7 @@ def _initialize_model(
     if builder_args.setup_caches:
         with torch.device(builder_args.device):
             model.setup_caches(
-                max_batch_size=1, max_seq_length=max_seq_length or model.config.max_seq_length
+                max_batch_size=1, max_seq_length=max_seq_length or model.config.text_transformer_args.max_seq_length
             )
 
     model.to(dtype=builder_args.precision)
```
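Call sites that used to read hyperparameters directly off `model.config` now reach through `text_transformer_args`, because `config` is a `ModelArgs` rather than a `TransformerArgs`. A quick sketch of the new access pattern, assuming a model built from the known-params table:

```python
from build.model import Model

model = Model.from_table("stories15M")  # an entry in build/known_model_params

# Per-transformer hyperparameters live one level down on the ModelArgs wrapper.
text_args = model.config.text_transformer_args
model.setup_caches(max_batch_size=1, max_seq_length=text_args.max_seq_length)
```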

build/gguf_loader.py

Lines changed: 15 additions & 11 deletions
```diff
@@ -13,11 +13,12 @@
 
 import torch
 
+from build.gguf_util import Q4_0, to_float
+from build.model import Model, ModelArgs, TransformerArgs
+
 from gguf import GGUFValueType
 from quantization.qops import LinearInt4 as WeightOnlyInt4Linear
 from quantization.quantize import pack_scales_and_zeros
-from build.gguf_util import Q4_0, to_float
-from build.model import TransformerArgs, Transformer
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -41,6 +42,7 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
     result = copy.deepcopy(gguf_name)
     for gguf_string, replacement in _name_replacements:
         result = result.replace(gguf_string, replacement)
+    result = "text_transformer." + result
     return result
 
 
@@ -107,22 +109,24 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     arch = metadata["general.architecture"]
     assert arch == "llama", "Only LLaMa models are supported by this converter."
 
-    model_args = TransformerArgs(
-        dim=metadata[f"{arch}.embedding_length"],
-        n_layers=metadata[f"{arch}.block_count"],
-        n_heads=metadata[f"{arch}.attention.head_count"],
-        n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
-        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
-        norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
-        hidden_dim=metadata[f"{arch}.feed_forward_length"],
+    model_args = ModelArgs(
+        TransformerArgs(
+            dim=metadata[f"{arch}.embedding_length"],
+            n_layers=metadata[f"{arch}.block_count"],
+            n_heads=metadata[f"{arch}.attention.head_count"],
+            n_local_heads=metadata[f"{arch}.attention.head_count_kv"],
+            vocab_size=len(metadata["tokenizer.ggml.tokens"]),
+            norm_eps=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
+            hidden_dim=metadata[f"{arch}.feed_forward_length"],
+        )
     )
 
     # TODO: what to do with rope args like
     # metadata.get(f"{arch}.rope.freq_base", None)
     # metadata.get(f"{arch}.rope.dimension_count", None)
 
     with torch.device("meta"):
-        model = Transformer(model_args)
+        model = Model(model_args)
     return model
 
 
```
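End to end, GGUF loading now produces the wrapper type: tensor names are rewritten onto the `text_transformer` submodule and the transformer config is nested inside `ModelArgs`. A usage sketch (the .gguf path is illustrative):

```python
from build.model import Model

# Parses the GGUF metadata, builds Model(ModelArgs(TransformerArgs(...))) on the
# meta device, then loads the renamed "text_transformer."-prefixed tensors.
model = Model.from_gguf("models/llama-2-7b.Q4_0.gguf")
```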

build/model.py

Lines changed: 71 additions & 32 deletions
```diff
@@ -59,21 +59,46 @@ def __post_init__(self):
         self.use_tiktoken = self.use_tiktoken == "True"
 
     @classmethod
-    def from_params(cls, params_path):
+    def from_params(cls, params):
         replace = [("rope_theta", "rope_base"), ("n_kv_heads", "n_local_heads")]
-        with open(params_path, "r") as f:
-            params = json.loads(f.read())
-        # Patch for llama3
-        for _from, _to in replace:
-            if _from in params:
-                params[_to] = params.pop(_from)
+        for _from, _to in replace:
+            if _from in params:
+                params[_to] = params.pop(_from)
         return cls(**params)
 
+@dataclass
+class ModelArgs:
+    text_transformer_args: TransformerArgs
+
+    def __post_init__(self):
+        assert self.text_transformer_args is not None
+        assert type(self.text_transformer_args) == TransformerArgs
+
+    @classmethod
+    def from_params(cls, params_path):
+        with open(params_path, "r") as f:
+            loaded_params = json.loads(f.read())
+
+        try:
+            # try to interpret as a single transformer config
+            text_transformer_args = TransformerArgs.from_params(
+                loaded_params
+            )
+        except TypeError:
+            # try to interpret as a dict of transformer configs
+            for name, params in loaded_params.items():
+                if name == "text":
+                    text_transformer_args = TransformerArgs.from_params(params)
+                else:
+                    raise ValueError(f"Unknown transformer name {name}")
+
+        return cls(text_transformer_args)
+
     @classmethod
     def from_table(cls, name: str):
         json_path = config_path / f"{name}.json"
         if json_path.is_file():
-            return TransformerArgs.from_params(json_path)
+            return ModelArgs.from_params(json_path)
         else:
             known_model_params = [
                 config.replace(".json", "") for config in os.listdir(config_path)
```
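`ModelArgs.from_params` accepts both params layouts, so existing single-transformer JSON files keep working while multi-transformer files name each transformer explicitly. A sketch of the two layouts it parses (the field values are illustrative, roughly at the stories15M scale):

```python
import json

from build.model import ModelArgs

# Legacy flat layout: the whole file is a single TransformerArgs.
flat = {"dim": 288, "n_layers": 6, "n_heads": 6, "vocab_size": 32000}

# New named layout: a dict of transformer configs; only "text" is recognized so far.
named = {"text": {"dim": 288, "n_layers": 6, "n_heads": 6, "vocab_size": 32000}}

for params in (flat, named):
    with open("/tmp/params.json", "w") as f:
        json.dump(params, f)
    args = ModelArgs.from_params("/tmp/params.json")
    print(args.text_transformer_args.dim)  # 288 either way
```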
```diff
@@ -86,7 +111,7 @@ def from_table(cls, name: str):
     def from_name(cls, name: str):
         json_path = config_path / f"{name}.json"
         if Path(json_path).is_file():
-            return TransformerArgs.from_params(json_path)
+            return ModelArgs.from_params(json_path)
 
         known_model_params = [
             config.replace(".json", "") for config in os.listdir(config_path)
@@ -113,7 +138,7 @@ def from_name(cls, name: str):
             f"Unknown model directory name {name}. Must be one of {known_model_params}."
         )
 
-        return TransformerArgs.from_params(config_path / f"{config[0]}.json")
+        return ModelArgs.from_params(config_path / f"{config[0]}.json")
 
 
 class KVCache(nn.Module):
@@ -144,6 +169,40 @@ def update(self, input_pos, k_val, v_val):
         return k_out, v_out
 
 
+class Model(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+        self.text_transformer = Transformer(config.text_transformer_args)
+
+    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        return self.text_transformer(idx, input_pos)
+
+    def setup_caches(self, max_batch_size, max_seq_length):
+        self.text_transformer.setup_caches(max_batch_size, max_seq_length)
+
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(ModelArgs.from_name(name))
+
+    @classmethod
+    def from_table(cls, name: str):
+        return cls(ModelArgs.from_table(name))
+
+    @classmethod
+    def from_params(cls, params_path: str):
+        return cls(ModelArgs.from_params(params_path))
+
+    @classmethod
+    def from_gguf(cls, gguf_path: str, **kwargs):
+        from build.gguf_loader import load_model_and_state_dict
+
+        model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
+        if state_dict != {}:
+            model.load_state_dict(state_dict, assign=True)
+        return model
+
+
 class Transformer(nn.Module):
     def __init__(self, config: TransformerArgs) -> None:
         super().__init__()
```
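`Model` keeps the public surface the rest of the codebase relies on (`forward`, `setup_caches`, the `from_*` constructors) while leaving room to hang additional transformers off the wrapper later. A minimal forward-pass sketch, assuming weights for the chosen config have already been loaded:

```python
import torch

from build.model import Model

model = Model.from_name("stories15M")  # fuzzy match against known configs
args = model.config.text_transformer_args
model.setup_caches(max_batch_size=1, max_seq_length=args.max_seq_length)

# Token ids plus their positions, exactly as the old Transformer.forward took them.
idx = torch.tensor([[1, 2, 3]])
input_pos = torch.arange(3)
logits = model(idx, input_pos)  # delegates to model.text_transformer
```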
```diff
@@ -180,7 +239,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             self.config.dim // self.config.n_heads,
             self.config.block_size * 2,
             self.config.rope_base,
-            use_scaled = self.config.use_scaled_rope,
+            use_scaled=self.config.use_scaled_rope,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)
         causal_mask = torch.tril(
@@ -201,27 +260,6 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         # print(f"logits shape: {logits.shape}")
         return logits
 
-    @classmethod
-    def from_name(cls, name: str):
-        return cls(TransformerArgs.from_name(name))
-
-    @classmethod
-    def from_table(cls, name: str):
-        return cls(TransformerArgs.from_table(name))
-
-    @classmethod
-    def from_params(cls, params_path: str):
-        return cls(TransformerArgs.from_params(params_path))
-
-    @classmethod
-    def from_gguf(cls, gguf_path: str, **kwargs):
-        from build.gguf_loader import load_model_and_state_dict
-
-        model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
-        if state_dict != {}:
-            model.load_state_dict(state_dict, assign=True)
-        return model
-
 
 class TransformerBlock(nn.Module):
     def __init__(self, config: TransformerArgs) -> None:
@@ -388,6 +426,7 @@ def apply_scaling(freqs: torch.Tensor):
         new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
 
+
 def precompute_freqs_cis(
     n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False
 ) -> Tensor:
```

distributed/parallelize_llama.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -59,7 +59,7 @@ def apply_tp(
     # after we apply TP to the model. Because we don't want to change model code
     # when applying TP. We need to have change to ensure KVCache has the correct
     # size as k and v.
-    model.config.n_local_heads = model.config.n_local_heads // tp_mesh.size()
+    model.config.text_transformer_args.n_local_heads = model.config.text_transformer_args.n_local_heads // tp_mesh.size()
 
     # Apply tensor parallelism to every transformer block
     for transformer_block in model.layers:
```
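The arithmetic is unchanged; only the attribute path moved. With tensor parallelism each rank keeps 1/tp of the KV heads, and since KVCache is sized from the config, the shard count has to be written back through the `ModelArgs` wrapper. A sketch with an illustrative mesh size:

```python
# Say the full model has 8 KV heads and the TP mesh spans 4 ranks (illustrative).
n_local_heads, tp_degree = 8, 4

# Each rank's KVCache only needs to hold the heads of its own shard.
per_rank_heads = n_local_heads // tp_degree  # 2

# This mirrors the commit's
#   model.config.text_transformer_args.n_local_heads //= tp_mesh.size()
print(per_rank_heads)
```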

docs/ADVANCED-USERS.md

Lines changed: 4 additions & 4 deletions
```diff
@@ -112,21 +112,21 @@ architecture, provided you have the model weights in llama format, the
 model parameters and the tokenizer model used by your language model.
 
 Some common models are recognized by torchchat based on their filename
-through `Transformer.from_name()` to perform a fuzzy match against a
+through `Model.from_name()` to perform a fuzzy match against a
 table of known model architectures. Alternatively, you can specify the
 index into that table with the option `--params-table ${INDEX}` where
 the index is the lookup key key in the [the list of known
 pconfigurations](https://github.com/pytorch/torchchat/tree/main/build/known_model_params)
 For example, for the stories15M model, this would be expressed as
 `--params-table stories15M`. (We use the model constructor
-`Transformer.from_table()`)
+`Model.from_table()`)
 
 For models using a configuration not in the list of known
 configurations, you can construct the model by initializing the
 `TransformerArgs` dataclass that controls model construction from a
 parameter json using the `params-path ${PARAMS_PATH}` containing the
-appropriate model parameters to initialize the `TransformerArgs` for the
-model. (We use the model constructor `Transformer.from_params()`).
+appropriate model parameters to initialize the `ModelArgs` for the
+model. (We use the model constructor `Model.from_params()`).
 
 The parameter file should be in JSON format specifying these
 parameters. You can find the `TransformerArgs` data class in
```
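In code, the two paths the doc describes map onto the renamed constructors (the params path below is illustrative):

```python
from build.model import Model

# Known configuration, looked up in build/known_model_params
# (the CLI equivalent is --params-table stories15M).
model = Model.from_table("stories15M")

# Custom configuration from your own params JSON
# (the CLI equivalent is --params-path).
model = Model.from_params("my_model/params.json")
```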
