-
Notifications
You must be signed in to change notification settings - Fork 341
/
Copy pathtokenizer.py
292 lines (242 loc) · 9.97 KB
/
tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Provides op for tokenizing a dataset."""
from typing import Dict, Iterable, Union, Literal, Sequence, Collection, List
from pathlib import Path
import tensorflow as tf
import tensorflow_text as tftxt
from MaxText import max_logging
import transformers
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from sentencepiece import SentencePieceProcessor
# Mapping from feature name (e.g. "inputs", "targets" in TokenizeOp) to its tensor value.
Features = Dict[str, tf.Tensor]
class TikTokenTokenizer:
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # pylint: disable=line-too-long

    # The tiktoken tokenizer can handle <=400k chars without
    # pyo3_runtime.PanicException, so longer inputs are encoded in windows
    # of at most this many characters.
    TIKTOKEN_MAX_ENCODE_CHARS = 400_000

    # https://github.com/openai/tiktoken/issues/195
    # Each window is further split so that no run of consecutive whitespace or
    # consecutive non-whitespace characters handed to tiktoken exceeds this length.
    MAX_NO_WHITESPACES_CHARS = 25_000

    def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
        """
        Initializes the Tokenizer with a Tiktoken model.

        Args:
          model_path (str): The path to the Tiktoken model file.
          add_bos (bool): Whether `encode` prepends the BOS token id.
          add_eos (bool): Whether `encode` appends the EOS token id.
        """
        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        # 10 explicit tokens + reserved tokens 5..250 (246 of them) gives
        # num_reserved_special_tokens (256) special tokens in total.
        special_tokens = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ] + [f"<|reserved_special_token_{i}|>" for i in range(5, self.num_reserved_special_tokens - 5)]
        # Special token ids are allocated directly after the base BPE vocabulary.
        self.special_tokens = {token: num_base_tokens + i for i, token in enumerate(special_tokens)}
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        self.eos = add_eos
        self.bos = add_bos
        max_logging.log(f"Reloaded tiktoken model from {model_path}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
        self.pad_id: int = -1
        self.stop_tokens = {
            self.special_tokens["<|end_of_text|>"],
            self.special_tokens["<|eot_id|>"],
        }
        max_logging.log(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")

    def encode(
        self,
        s: str,
        *,
        allowed_special: Union[Literal["all"], Collection[str]] = (),
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Long inputs are encoded in windows of at most `TIKTOKEN_MAX_ENCODE_CHARS`
        characters. Each window is further split so that no run of consecutive
        whitespace or non-whitespace characters exceeds `MAX_NO_WHITESPACES_CHARS`.
        BOS/EOS ids are added according to the `add_bos`/`add_eos` flags given at
        construction.

        Args:
          s (str): The input string to be encoded.
          allowed_special ("all" | Collection[str]): special tokens that are
            encoded as special tokens when their text appears in `s`.
          disallowed_special ("all" | Collection[str]): special tokens that raise
            an error when their text appears in `s`.

        Returns:
          list[int]: A list of token IDs.

        By default, setting disallowed_special=() encodes a string by ignoring
        special tokens. Specifically:
          - Setting `disallowed_special` to () causes all text corresponding
            to special tokens to be encoded as natural text (instead of raising
            an error).
          - Setting `allowed_special` to "all" causes all text corresponding
            to special tokens to be encoded as special tokens.
        """
        assert isinstance(s, str)

        # "all" must reach tiktoken verbatim: set("all") would become {"a", "l"}
        # and silently disable every special token.
        if allowed_special != "all":
            allowed_special = set(allowed_special)

        substrs = (
            substr
            for i in range(0, len(s), self.TIKTOKEN_MAX_ENCODE_CHARS)
            for substr in self._split_whitespaces_or_nonwhitespaces(
                s[i : i + self.TIKTOKEN_MAX_ENCODE_CHARS], self.MAX_NO_WHITESPACES_CHARS
            )
        )

        # Start with the BOS id (if requested) rather than inserting it at the
        # front afterwards, which would shift the whole list.
        tokens: List[int] = [self.bos_id] if self.bos else []
        for substr in substrs:
            tokens.extend(
                self.model.encode(
                    substr,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )
        if self.eos:
            tokens.append(self.eos_id)
        return tokens

    def decode(self, t) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
          t (List[int]): The list of token IDs to be decoded.

        Returns:
          str: The decoded string.
        """
        return self.model.decode(t)

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(s: str, max_consecutive_slice_len: int):
        """
        Splits the string `s` so that each substring contains no more than
        `max_consecutive_slice_len` consecutive whitespaces or consecutive
        non-whitespaces.

        Yields the pieces in order; their concatenation equals `s`. An empty `s`
        yields a single empty string.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if s else False
        slice_start = 0

        for i, ch in enumerate(s):
            is_now_space = ch.isspace()
            if current_slice_is_space ^ is_now_space:
                # Character class changed: a new run starts at this character.
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    # The run is too long: cut before this character.
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]
class SentencePieceTokenizer:
    """
    Sentencepiece tokenizer built on `tensorflow_text.SentencepieceTokenizer`.

    The serialized model is read via `tf.io.gfile` and handed to tensorflow_text
    with the requested BOS/EOS behaviour.

    NOTE(review): `encode`/`decode` return whatever tensorflow_text's
    `tokenize`/`detokenize` return (TokenizeOp passes a tf tensor straight into
    `encode`). That is presumably a tensor rather than the `List[int]`/`str` in
    the annotations; confirm.
    """

    def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
        max_logging.log(f"Tokenizer path: {model_path}")
        with tf.io.gfile.GFile(model_path, "rb") as fp:
            serialized_model = fp.read()
        tokenizer = tftxt.SentencepieceTokenizer(
            model=serialized_model,
            add_bos=add_bos,
            add_eos=add_eos,
            reverse=False,
        )
        self.sp_tokenizer = tokenizer
        # Ids of the "<pad>" and "<unk>" pieces in the loaded vocabulary.
        self.pad_id = tokenizer.string_to_id("<pad>")
        self.unk_id = tokenizer.string_to_id("<unk>")

    def encode(self, s: str) -> List[int]:
        """Tokenize `s` with the wrapped tensorflow_text tokenizer."""
        tokenizer = self.sp_tokenizer
        return tokenizer.tokenize(s)

    def decode(self, t: Sequence[int]) -> str:
        """Detokenize `t` with the wrapped tensorflow_text tokenizer."""
        tokenizer = self.sp_tokenizer
        return tokenizer.detokenize(t)
class SentencePieceTokenizerGrain:
    """
    Sentencepiece tokenizer built on the `sentencepiece` Python package.

    BOS/EOS ids are added manually in `encode` according to the flags given at
    construction.
    """

    def __init__(self, model_path: str, add_bos: bool, add_eos: bool):
        max_logging.log(f"Loading sentencepiece tokenizer: {model_path}")
        processor = SentencePieceProcessor()
        self._tokenizer_model = processor
        processor.Load(model_path)

        self.pad_id = processor.pad_id()
        self.unk_id = processor.unk_id()
        self.bos_id = processor.bos_id()
        self.eos_id = processor.eos_id()

        self.add_bos = add_bos
        self.add_eos = add_eos

    def encode(self, s: str) -> List[int]:
        """Return the token ids of `s`, optionally wrapped in BOS/EOS ids."""
        body = self._tokenizer_model.EncodeAsIds(s)
        head = [self.bos_id] if self.add_bos else []
        tail = [self.eos_id] if self.add_eos else []
        return head + body + tail

    def decode(self, t: Sequence[int]) -> str:
        """Return the text for token ids `t`."""
        processor = self._tokenizer_model
        return processor.DecodeIds(t)
class HFTokenizer:
    """
    Tokenizer wrapper around a Hugging Face `transformers.AutoTokenizer`.

    BOS/EOS insertion is requested from the HF tokenizer through the
    `add_bos_token` / `add_eos_token` loading kwargs.
    """

    def __init__(self, model_path: str, add_bos: bool, add_eos: bool, hf_access_token: str):
        max_logging.log(f"Loading HF tokenizer: {model_path}")
        hf_tok = transformers.AutoTokenizer.from_pretrained(
            model_path,
            add_bos_token=add_bos,
            add_eos_token=add_eos,
            token=hf_access_token,
        )
        self.tokenizer = hf_tok
        # Special-token ids as reported by the HF tokenizer.
        self.pad_id, self.unk_id, self.bos_id, self.eos_id = (
            hf_tok.pad_token_id,
            hf_tok.unk_token_id,
            hf_tok.bos_token_id,
            hf_tok.eos_token_id,
        )

    def encode(self, s: str) -> List[int]:
        """Delegate to the HF tokenizer's `encode`."""
        hf_tok = self.tokenizer
        return hf_tok.encode(s)

    def decode(self, t: Sequence[int]) -> str:
        """Delegate to the HF tokenizer's `decode`."""
        hf_tok = self.tokenizer
        return hf_tok.decode(t)
def build_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token, dataset_type):
    """Loads the tokenizer at `tokenizer_path`.

    Args:
      tokenizer_path: path or model id passed to the chosen tokenizer class.
      tokenizer_type: one of "tiktoken", "huggingface" or "sentencepiece".
      add_bos: whether the tokenizer should prepend the BOS token.
      add_eos: whether the tokenizer should append the EOS token.
      hf_access_token: access token forwarded only to the Hugging Face tokenizer.
      dataset_type: for "sentencepiece", "tfds" selects the tensorflow_text
        implementation; any other value selects the sentencepiece-package one.

    Raises:
      ValueError: for an unrecognised `tokenizer_type`.
    """
    max_logging.log(f"Tokenizer path: {tokenizer_path}")

    if tokenizer_type == "tiktoken":
        # Sanity check that the path actually looks like a tiktoken model.
        assert "tiktoken" in tokenizer_path, f"Invalid tokenizer type: {tokenizer_type} chosen for {tokenizer_path}"
        return TikTokenTokenizer(tokenizer_path, add_bos, add_eos)

    if tokenizer_type == "huggingface":
        return HFTokenizer(tokenizer_path, add_bos, add_eos, hf_access_token)

    if tokenizer_type == "sentencepiece":
        sp_class = SentencePieceTokenizer if dataset_type == "tfds" else SentencePieceTokenizerGrain
        return sp_class(tokenizer_path, add_bos, add_eos)

    raise ValueError(f"Invalid tokenizer_type:{tokenizer_type} chosen in config")
def TokenizeOp(tokenizer, features: Features, data_keys: Iterable[str] = ("inputs", "targets")) -> Features:
    """Op for tokenization.

    Replaces each entry of `features` named in `data_keys` with its token ids,
    mutating and returning the same dict.

    - TikTokenTokenizer / HFTokenizer: encoding runs in Python through
      `tf.py_function`, producing an int32 tensor.
    - SentencePieceTokenizer: the tensor is passed straight to `tokenizer.encode`.
    - Any other tokenizer type: the features are returned unchanged.
    """

    def _encode_in_python(string_tensor):
        # py_function hands us an eager tensor of UTF-8 bytes; decode to str first.
        text = string_tensor.numpy().decode("utf-8")
        token_ids = tokenizer.encode(text)
        return [token_ids]

    # The dispatch depends only on the tokenizer, so decide it once.
    encode_via_py_function = isinstance(tokenizer, (TikTokenTokenizer, HFTokenizer))
    encode_directly = isinstance(tokenizer, SentencePieceTokenizer)

    for key in data_keys:
        if encode_via_py_function:
            outputs = tf.py_function(_encode_in_python, [features[key]], Tout=[tf.int32])
            features[key] = outputs[0]
        elif encode_directly:
            features[key] = tokenizer.encode(features[key])
    return features