-
Notifications
You must be signed in to change notification settings - Fork 359
/
Copy pathmodels.py
714 lines (623 loc) · 25.3 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
"""Wrapper for fine-tuned HuggingFace models in LIT."""
# TODO(b/261736863): Update to PEP 585 typings, consider using f-strings, and
# make common substrings into module CONSTANTS.
from collections.abc import Iterable, Sequence
import math
import os
import re
import threading
from typing import Any, Optional

import attr
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.examples.glue import model_utils
from lit_nlp.lib import file_cache
from lit_nlp.lib import utils
import numpy as np
import tensorflow as tf
import tf_keras as keras
import transformers
# NOTE(review): this presumably makes TensorFlow/transformers use the legacy
# Keras 2 package (tf_keras, imported above as `keras`). It runs after those
# modules are imported; confirm the variable is still read in time (i.e. before
# `tf.keras` is first resolved), or move it above the imports.
os.environ.update({"TF_USE_LEGACY_KERAS": "1"})

# Short aliases for LIT / transformers types used throughout this module.
JsonDict = lit_types.JsonDict
Spec = lit_types.Spec
TFSequenceClassifierOutput = (
    transformers.modeling_tf_outputs.TFSequenceClassifierOutput
)
@attr.s(auto_attribs=True, kw_only=True)
class GlueModelConfig(object):
  """Options controlling a GlueModel's preprocessing, inputs and outputs.

  All fields are keyword-only (`kw_only=True`).
  """

  # --- Preprocessing ---
  # Maximum number of wordpieces per encoded example.
  max_seq_length: int = 128
  # Number of examples per predict_minibatch() call.
  inference_batch_size: int = 32

  # --- Input fields ---
  # Name of the first text segment field.
  text_a_name: str = "sentence1"
  # Name of the second text segment field; set to None for single-segment.
  text_b_name: Optional[str] = "sentence2"
  # Name of the label field.
  label_name: str = "label"

  # --- Outputs ---
  # Class-name vocabulary; set to None for regression.
  labels: Optional[list[str]] = None
  # Index of the "null" class, if any (passed to MulticlassPreds as null_idx).
  null_label_idx: Optional[int] = None
  # When True, compute and return gradients.
  compute_grads: bool = True
  # When True, return per-layer attention heads.
  output_attention: bool = True
  # When True, return embeddings.
  output_embeddings: bool = True

  @classmethod
  def init_spec(cls) -> lit_types.Spec:
    """Returns the spec of constructor options that can be set when loading."""
    spec = {
        "model_name_or_path": lit_types.String(
            default="bert-base-uncased", required=False
        ),
        "max_seq_length": lit_types.Integer(
            default=128, min_val=1, max_val=512, required=False
        ),
        "inference_batch_size": lit_types.Integer(
            default=32, min_val=1, max_val=64, required=False
        ),
    }
    # Output toggles, all enabled by default.
    for flag_name in ("compute_grads", "output_attention", "output_embeddings"):
      spec[flag_name] = lit_types.Boolean(default=True, required=False)
    return spec
class GlueModel(lit_model.BatchedModel):
  """GLUE benchmark model, using Keras/TF2 and Huggingface Transformers.

  This is a general-purpose classification or regression model. It works for
  one- or two-segment input, and predicts either a multiclass label or
  a regression score. See GlueModelConfig for available options.

  This implements the LIT API for inference (e.g. input_spec(), output_spec(),
  and predict()), but also provides a train() method to run fine-tuning.

  This is a full-featured implementation, which includes embeddings, attention,
  gradients, as well as support for the different input and output types above.
  """

  def _verify_num_layers(self, hidden_states: Sequence[Any]):
    """Verifies that the model returned one activation per layer + embeddings.

    Args:
      hidden_states: activations returned by the model; the first entry is the
        embedding output, followed by one entry per transformer layer.

    Raises:
      ValueError: if len(hidden_states) != num_hidden_layers + 1.
    """
    # First entry is embeddings, then output from each transformer layer.
    expected_hidden_states_len = self.model.config.num_hidden_layers + 1
    actual_hidden_states_len = len(hidden_states)
    if actual_hidden_states_len != expected_hidden_states_len:
      raise ValueError(
          "Unexpected size of hidden_states. Should be one "
          "more than the number of hidden layers to account "
          "for the embeddings. Expected "
          f"{expected_hidden_states_len}, got "
          f"{actual_hidden_states_len}."
      )

  @property
  def is_regression(self) -> bool:
    """True when no label vocabulary is configured (regression model)."""
    return self.config.labels is None

  # TODO(b/254110131): Move file_cache.cached_path() call inside this __init__
  # function to reduce boilerplate in other locations (e.g., TCAV tests).
  def __init__(self, model_name_or_path="bert-base-uncased", **config_kw):
    """Builds the config, then loads the tokenizer and model.

    Args:
      model_name_or_path: name of a pretrained model, or a path to a saved model
        (a directory, or a .tar.gz archive that is extracted via file_cache).
      **config_kw: fields of GlueModelConfig.
    """
    self.config = GlueModelConfig(**config_kw)
    # Create the lock first, so that _load_model() (which may be overridden,
    # e.g. for testing) can already use lock-protected helpers like
    # _get_tokens().
    self._lock = threading.Lock()
    self._load_model(model_name_or_path)

  def _load_model(self, model_name_or_path):
    """Loads tokenizer, vocabulary and model. Can be overridden for testing."""
    # Normally path is a directory; if it's an archive file, download and
    # extract to the transformers cache.
    if model_name_or_path.endswith(".tar.gz"):
      model_name_or_path = file_cache.cached_path(
          model_name_or_path, extract_compressed_file=True
      )
    self.tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path
    )
    # Token string for every token id, in id order.
    self.vocab = self.tokenizer.convert_ids_to_tokens(
        range(len(self.tokenizer))
    )
    model_config = transformers.AutoConfig.from_pretrained(
        model_name_or_path,
        # One output for regression, one logit per class otherwise.
        num_labels=1 if self.is_regression else len(self.config.labels),
        return_dict=False,  # default for training; overridden for predict
        output_attentions=self.config.output_attention,
    )
    self.model = model_utils.load_pretrained(
        transformers.TFAutoModelForSequenceClassification,
        model_name_or_path,
        config=model_config,
    )

  def _word_embeddings(self):
    """Returns the word-embedding matrix variable of the BERT embedding layer.

    Depending on the transformers version, TFBertEmbeddings exposes this matrix
    as `word_embeddings` or as `weight`; both are supported so that
    get_embedding_table() and predict_minibatch() agree.
    """
    # TODO(b/236276775): Unify on the TFBertEmbeddings.weight API after
    # transformers is updated to v4.25.1 (or newer).
    embeddings_layer = self.model.bert.embeddings
    if hasattr(embeddings_layer, "word_embeddings"):
      return embeddings_layer.word_embeddings
    return embeddings_layer.weight

  def _get_tokens(self, ex: JsonDict, field_name: str) -> list[str]:
    """Returns `tokens_<field_name>` if set, else tokenizes `ex[field_name]`."""
    # NOTE(review): tokenizer calls are serialized with a lock, presumably
    # because the tokenizer is not safe to call from several threads at once.
    with self._lock:
      return ex.get("tokens_" + field_name) or self.tokenizer.tokenize(
          ex[field_name]
      )

  def _preprocess(self, inputs: Iterable[JsonDict]) -> dict[str, tf.Tensor]:
    """Tokenizes and encodes a batch of examples for the model.

    Args:
      inputs: examples in LIT format.

    Returns:
      The encoded batch (a transformers.BatchEncoding with e.g. "input_ids"
      and "attention_mask"), truncated to config.max_seq_length.
    """
    # Materialize once: the inputs are iterated once per text segment, which
    # would silently yield nothing the second time for a one-shot iterator.
    inputs = list(inputs)
    # Use pretokenized input if available.
    tokens_a = [self._get_tokens(ex, self.config.text_a_name) for ex in inputs]
    tokens_b = None
    if self.config.text_b_name:
      tokens_b = [
          self._get_tokens(ex, self.config.text_b_name) for ex in inputs
      ]
    # Use custom tokenizer call to make sure we don't mangle pre-split
    # wordpieces in pretokenized input.
    encoded_input = model_utils.batch_encode_pretokenized(
        self.tokenizer,
        tokens_a,
        tokens_b,
        max_length=self.config.max_seq_length,
    )
    return encoded_input  # pytype: disable=bad-return-type

  def _make_dataset(self, inputs: Iterable[JsonDict]) -> tf.data.Dataset:
    """Makes a tf.data.Dataset of (encoded inputs, label) from LIT examples.

    Labels are float32 scores for regression, or int64 indices into
    config.labels for classification. Every example must carry the label field.
    """
    # Materialize once: the inputs are read for features and again for labels.
    inputs = list(inputs)
    encoded_input = self._preprocess(inputs)
    if self.is_regression:
      labels = tf.constant(
          [ex[self.config.label_name] for ex in inputs], dtype=tf.float32
      )
    else:
      # config.labels is not None here, since this is not a regression model.
      label_vocab = self.config.labels
      labels = tf.constant(
          [label_vocab.index(ex[self.config.label_name]) for ex in inputs],
          dtype=tf.int64,
      )
    # encoded_input is actually a transformers.BatchEncoding
    # object, which tf.data.Dataset doesn't like. Convert to a regular dict.
    return tf.data.Dataset.from_tensor_slices((dict(encoded_input), labels))

  def train(
      self,
      train_inputs: list[JsonDict],
      validation_inputs: list[JsonDict],
      learning_rate=2e-5,
      batch_size=32,
      num_epochs=3,
      keras_callbacks=None,
  ):
    """Runs fine-tuning with Keras `fit()`.

    Args:
      train_inputs: labeled training examples in LIT format.
      validation_inputs: labeled validation examples in LIT format.
      learning_rate: Adam learning rate.
      batch_size: training batch size; validation uses twice this size.
      num_epochs: number of training epochs.
      keras_callbacks: optional list of Keras callbacks passed to `fit()`.

    Returns:
      The History object returned by `fit()`.
    """
    train_dataset = (
        self._make_dataset(train_inputs)
        .shuffle(128)
        .batch(batch_size)
        .repeat(-1)
    )
    # Use larger batch for validation since inference is about 1/2 memory usage
    # of backprop.
    eval_batch_size = 2 * batch_size
    validation_dataset = self._make_dataset(validation_inputs).batch(
        eval_batch_size
    )

    # Prepare model for training.
    opt = keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08)
    if self.is_regression:
      loss = keras.losses.MeanSquaredError()
      metric = keras.metrics.RootMeanSquaredError("rmse")
    else:
      loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
      metric = keras.metrics.SparseCategoricalAccuracy("accuracy")
    self.model.compile(optimizer=opt, loss=loss, metrics=[metric])

    # The training dataset repeats forever, so an epoch is a step count. Use at
    # least one step, so a training set smaller than one batch still trains.
    steps_per_epoch = max(1, len(train_inputs) // batch_size)
    # The validation dataset is finite (not repeated): round up so that the
    # final partial batch is evaluated too, with at least one step.
    validation_steps = max(
        1, math.ceil(len(validation_inputs) / eval_batch_size)
    )
    history = self.model.fit(
        train_dataset,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=validation_dataset,
        validation_steps=validation_steps,
        callbacks=keras_callbacks,
        verbose=2,
    )
    return history

  def save(self, path: str):
    """Saves model weights and tokenizer info.

    To re-load, pass the path to the constructor instead of the name of a
    base model.

    Args:
      path: directory to save to. Will write several files here. Missing
        parent directories are created.
    """
    os.makedirs(path, exist_ok=True)
    self.tokenizer.save_pretrained(path)
    self.model.save_pretrained(path)

  def _segment_slicers(self, tokens: list[str]):
    """Slicers along the tokens dimension for each segment.

    For tokens ['[CLS]', a0, a1, ..., '[SEP]', b0, b1, ..., '[SEP]'],
    we want to get the slices [a0, a1, ...] and [b0, b1, ...]

    Args:
      tokens: <string>[num_tokens], including special tokens

    Returns:
      (slicer_a, slicer_b), slice objects
    """
    try:
      split_point = tokens.index(self.tokenizer.sep_token)
    except ValueError:
      # No separator found: treat everything before the last token as text_a.
      split_point = len(tokens) - 1
    slicer_a = slice(1, split_point)  # start after [CLS]
    slicer_b = slice(split_point + 1, len(tokens) - 1)  # end before last [SEP]
    return slicer_a, slicer_b

  def _postprocess(self, output: dict[str, Any]):
    """Per-example postprocessing, on NumPy output.

    Strips padding, splits tokens, embeddings and gradients per text segment,
    maps the gradient target index back to its class name, and trims
    attention matrices to the real tokens. Mutates and returns `output`.
    """
    ntok = output.pop("ntok")
    output["tokens"] = self.tokenizer.convert_ids_to_tokens(
        output.pop("input_ids")[:ntok]
    )

    # Tokens for each segment, individually.
    slicer_a, slicer_b = self._segment_slicers(output["tokens"])
    output["tokens_" + self.config.text_a_name] = output["tokens"][slicer_a]
    if self.config.text_b_name:
      output["tokens_" + self.config.text_b_name] = output["tokens"][slicer_b]

    # Embeddings for each segment, individually.
    if self.config.output_embeddings:
      output["input_embs_" + self.config.text_a_name] = output["input_embs"][
          slicer_a
      ]
      if self.config.text_b_name:
        output["input_embs_" + self.config.text_b_name] = output["input_embs"][
            slicer_b
        ]

    # Gradients for each segment, individually.
    if self.config.compute_grads:
      # Gradients for the CLS token.
      output["cls_grad"] = output["input_emb_grad"][0]
      output["token_grad_" + self.config.text_a_name] = output[
          "input_emb_grad"
      ][slicer_a]
      if self.config.text_b_name:
        output["token_grad_" + self.config.text_b_name] = output[
            "input_emb_grad"
        ][slicer_b]

      # TODO(b/294613507): remove output[self.config.label_name] once TCAV
      # is updated.
      if not self.is_regression:
        # Return the label corresponding to the class index used for gradients.
        output[self.config.label_name] = self.config.labels[
            output[self.config.label_name]
        ]  # pytype: disable=container-type-mismatch

      # Remove "input_emb_grad" since it's not in the output spec.
      del output["input_emb_grad"]

    if not self.config.output_attention:
      return output

    # Process attention.
    for key in output:
      if not re.match(r"layer_(\d+)/attention", key):
        continue
      # Select only real tokens, since most of this matrix is padding.
      # <float32>[num_heads, max_seq_length, max_seq_length]
      # -> <float32>[num_heads, num_tokens, num_tokens]
      output[key] = output[key][:, :ntok, :ntok].transpose((0, 2, 1))
      # Make a copy of this array to avoid memory leaks, since NumPy otherwise
      # keeps a pointer around that prevents the source array from being GCed.
      output[key] = output[key].copy()  # pytype: disable=attribute-error

    return output

  def _scatter_embs(
      self, passed_input_embs, input_embs, batch_indices, offsets
  ):
    """Scatters custom passed embeddings into the default model embeddings.

    Args:
      passed_input_embs: per-example custom embeddings, or None for examples
        without them. Non-None entries are presumably
        [num_segment_tokens, emb_size] arrays or nested lists.
      input_embs: <tf.float32>[batch_size, num_tokens, emb_size], the default
        model embeddings.
      batch_indices: batch index of every non-None entry of passed_input_embs,
        in the same order.
      offsets: for each entry of batch_indices, the sequence position at which
        its custom embeddings start (number of tokens before the segment).

    Returns:
      The default model embeddings with scattered custom embeddings.
    """
    # Ragged: scatter_batch_size entries of <float>[num_tokens_i, emb_size].
    filtered_embs = [emb for emb in passed_input_embs if emb is not None]
    # One update row per scattered token:
    # <float>[total_scattered_tokens, emb_size]. tensor_scatter_nd_update
    # requires updates to match input_embs' dtype, and embeddings passed as
    # plain Python lists would concatenate to float64, so cast explicitly.
    updates = tf.cast(np.concatenate(filtered_embs), input_embs.dtype)
    # (batch_index, sequence_index) for each row of `updates`, in order.
    scatter_indices = [
        [batch_index, offset + token_index]
        for batch_index, sentence_embs, offset in zip(
            batch_indices, filtered_embs, offsets
        )
        for token_index in range(len(sentence_embs))
    ]
    # <tf.float32>[batch_size, num_tokens + num_special_tokens, emb_size]
    return tf.tensor_scatter_nd_update(input_embs, scatter_indices, updates)

  def scatter_all_embeddings(self, inputs, input_embs):
    """Scatters custom passed embeddings for text segment inputs.

    Args:
      inputs: the model inputs, which contain any custom embeddings to scatter
        in `input_embs_<text_a_name>` / `input_embs_<text_b_name>`.
      input_embs: the default model embeddings.

    Returns:
      The default model embeddings with scattered custom embeddings.

    Raises:
      ValueError: if an example passes text_b embeddings without text_a
        embeddings, since text_b's start position is then unknown.
    """
    # Materialize once: the inputs are read once per text segment.
    inputs = list(inputs)
    text_a_key = "input_embs_" + self.config.text_a_name
    passed_input_embs_a = [ex.get(text_a_key) for ex in inputs]
    batch_indices_a = [
        index
        for index, emb in enumerate(passed_input_embs_a)
        if emb is not None
    ]
    # text_a starts at index 1, right after the [CLS] token.
    if batch_indices_a:
      input_embs = self._scatter_embs(
          passed_input_embs_a,
          input_embs,
          batch_indices_a,
          offsets=np.ones(len(batch_indices_a), dtype=np.int64),
      )

    if not self.config.text_b_name:
      return input_embs

    text_b_key = "input_embs_" + self.config.text_b_name
    passed_input_embs_b = [ex.get(text_b_key) for ex in inputs]
    batch_indices_b = [
        index
        for index, emb in enumerate(passed_input_embs_b)
        if emb is not None
    ]
    if batch_indices_b:
      # text_b starts after [CLS] [text_a tokens] [SEP], so its offset is the
      # length of the text_a embeddings passed for the SAME example, plus 2.
      offsets_b = []
      for index in batch_indices_b:
        embs_a = passed_input_embs_a[index]
        if embs_a is None:
          raise ValueError(
              f"{text_b_key} was passed without {text_a_key} for example "
              f"{index}; both are needed to position the text_b embeddings."
          )
        offsets_b.append(len(embs_a) + 2)
      input_embs = self._scatter_embs(
          passed_input_embs_b,
          input_embs,
          batch_indices_b,
          offsets=np.array(offsets_b, dtype=np.int64),
      )
    return input_embs

  def get_target_scores(self, inputs: Iterable[JsonDict], scores):
    """Gets the score of each example's target class.

    The target class is the example's label (a class name or an index) when it
    is present and not None; otherwise the argmax of `scores`.

    Args:
      inputs: list of input examples
      scores: <tf.float32>[batch_size, num_classes], either logits or probas

    Returns:
      (target_scores, grad_idxs): target_scores is <tf.float32>[batch_size], the
      score of each example's target class; grad_idxs is the list of target
      class indices, one per example.
    """
    arg_max = tf.math.argmax(scores, axis=-1).numpy()
    grad_idxs = []
    for i, ex in enumerate(inputs):
      label = ex.get(self.config.label_name)
      if label is None:
        # No label given: use the predicted class.
        grad_idxs.append(arg_max[i])
      elif isinstance(label, str) and self.config.labels is not None:
        # Convert the class name to its index.
        grad_idxs.append(self.config.labels.index(label))
      else:
        grad_idxs.append(label)
    # list of tuples (batch idx, label idx)
    gather_indices = list(enumerate(grad_idxs))
    # <tf.float32>[batch_size]
    return tf.gather_nd(scores, gather_indices), grad_idxs

  ##
  # LIT API implementation
  def max_minibatch_size(self):
    """Maximum number of examples per predict_minibatch() call."""
    return self.config.inference_batch_size

  def get_embedding_table(self):
    """Returns (vocab, word-embedding matrix as a NumPy array)."""
    return self.vocab, self._word_embeddings().numpy()

  def predict_minibatch(self, inputs: Iterable[JsonDict]):
    """Runs the model on one batch of examples.

    Args:
      inputs: examples in LIT format. Optional `input_embs_<segment>` fields
        replace the corresponding token embeddings, and an optional label
        selects the class used as the gradient target.

    Returns:
      An iterator of per-example output dicts, produced by _postprocess().
    """
    # Materialize once: the inputs are read by several helpers below.
    inputs = list(inputs)
    # Use watch_accessed_variables to save memory by having the tape do nothing
    # if we don't need gradients.
    with tf.GradientTape(
        watch_accessed_variables=self.config.compute_grads
    ) as tape:
      encoded_input = self._preprocess(inputs)

      # Gathers word embeddings from BERT model embedding layer using input ids
      # of the tokens.
      input_ids = encoded_input["input_ids"]
      word_embeddings = self._word_embeddings()
      # <tf.float32>[batch_size, num_tokens, emb_size]
      input_embs = tf.gather(word_embeddings, input_ids)

      # Scatter in any passed in embeddings.
      # <tf.float32>[batch_size, num_tokens, emb_size]
      input_embs = self.scatter_all_embeddings(inputs, input_embs)

      tape.watch(input_embs)  # Watch input_embs for gradient calculation.

      # Feed the (possibly modified) embeddings instead of the token ids.
      model_inputs = encoded_input.copy()
      model_inputs["input_ids"] = None
      out: TFSequenceClassifierOutput = self.model(
          model_inputs,
          inputs_embeds=input_embs,
          training=False,
          output_hidden_states=True,
          output_attentions=True,
          return_dict=True,
      )

      batched_outputs = {
          "input_ids": encoded_input["input_ids"],
          "ntok": tf.reduce_sum(encoded_input["attention_mask"], axis=1),
          "cls_emb": out.hidden_states[-1][:, 0],  # last layer, first token
      }

      if self.config.output_embeddings:
        batched_outputs["input_embs"] = input_embs

        self._verify_num_layers(out.hidden_states)

        # <float32>[batch_size, num_tokens, 1]
        token_mask = tf.expand_dims(
            tf.cast(encoded_input["attention_mask"], tf.float32), axis=2
        )
        # <float32>[batch_size, 1]
        denom = tf.reduce_sum(token_mask, axis=1)
        for i, layer_output in enumerate(out.hidden_states):
          # layer_output is <float32>[batch_size, num_tokens, emb_dim]
          # average over tokens to get <float32>[batch_size, emb_dim]
          batched_outputs[f"layer_{i}/avg_emb"] = (
              tf.reduce_sum(layer_output * token_mask, axis=1) / denom
          )

      if self.config.output_attention:
        if len(out.attentions) != self.model.config.num_hidden_layers:
          raise ValueError(
              "Unexpected size of attentions. Should be the same "
              "size as the number of hidden layers. Expected "
              f"{self.model.config.num_hidden_layers}, got "
              f"{len(out.attentions)}."
          )
        for i, layer_attention in enumerate(out.attentions):
          batched_outputs[f"layer_{i+1}/attention"] = layer_attention

      if self.is_regression:
        # <tf.float32>[batch_size]
        batched_outputs["score"] = tf.squeeze(out.logits, axis=-1)
        # <tf.float32>[batch_size], a single target per example
        scalar_targets = batched_outputs["score"]
      else:
        # <tf.float32>[batch_size, num_labels]
        batched_outputs["probas"] = tf.nn.softmax(out.logits, axis=-1)
        # <tf.float32>[batch_size], a single target per example
        scalar_targets, grad_idxs = self.get_target_scores(
            inputs, batched_outputs["probas"]
        )
        # TODO(b/294613507): remove once TCAV updated.
        if self.config.compute_grads:
          batched_outputs[self.config.label_name] = tf.convert_to_tensor(
              grad_idxs
          )

    # Request gradients after the tape is run.
    # Note: embs[0] includes position and segment encodings, as well as subword
    # embeddings.
    if self.config.compute_grads:
      # <tf.float32>[batch_size, num_tokens, emb_dim]
      batched_outputs["input_emb_grad"] = tape.gradient(
          scalar_targets, input_embs
      )

    detached_outputs = {
        k: v.numpy() for k, v in batched_outputs.items() if v is not None
    }
    # Sequence of dicts, one per example.
    unbatched_outputs = utils.unbatch_preds(detached_outputs)
    return map(self._postprocess, unbatched_outputs)

  def input_spec(self) -> Spec:
    """Spec of the fields this model reads from each example."""
    ret = {}
    ret[self.config.text_a_name] = lit_types.TextSegment()
    ret["tokens_" + self.config.text_a_name] = lit_types.Tokens(
        parent=self.config.text_a_name, required=False
    )
    if self.config.text_b_name:
      ret[self.config.text_b_name] = lit_types.TextSegment()
      ret["tokens_" + self.config.text_b_name] = lit_types.Tokens(
          parent=self.config.text_b_name, required=False
      )
    if self.is_regression:
      ret[self.config.label_name] = lit_types.Scalar(required=False)
    else:
      ret[self.config.label_name] = lit_types.CategoryLabel(
          required=False, vocab=self.config.labels
      )
    if self.config.output_embeddings:
      # The input_embs_ fields are used for Integrated Gradients.
      text_a_embs = "input_embs_" + self.config.text_a_name
      ret[text_a_embs] = lit_types.TokenEmbeddings(
          align="tokens", required=False
      )
      if self.config.text_b_name:
        text_b_embs = "input_embs_" + self.config.text_b_name
        ret[text_b_embs] = lit_types.TokenEmbeddings(
            align="tokens", required=False
        )
    return ret

  def output_spec(self) -> Spec:
    """Spec of the per-example fields returned by predict_minibatch()."""
    ret = {"tokens": lit_types.Tokens()}
    ret["tokens_" + self.config.text_a_name] = lit_types.Tokens(
        parent=self.config.text_a_name
    )
    if self.config.text_b_name:
      ret["tokens_" + self.config.text_b_name] = lit_types.Tokens(
          parent=self.config.text_b_name
      )
    if self.is_regression:
      ret["score"] = lit_types.RegressionScore(parent=self.config.label_name)
    else:
      ret["probas"] = lit_types.MulticlassPreds(
          parent=self.config.label_name,
          vocab=self.config.labels,
          null_idx=self.config.null_label_idx,
      )

    if self.config.output_embeddings:
      ret["cls_emb"] = lit_types.Embeddings()
      # Average embeddings, one per layer including embeddings.
      for i in range(1 + self.model.config.num_hidden_layers):
        ret[f"layer_{i}/avg_emb"] = lit_types.Embeddings()

      # The input_embs_ fields are used for Integrated Gradients.
      ret["input_embs_" + self.config.text_a_name] = lit_types.TokenEmbeddings(
          align="tokens_" + self.config.text_a_name
      )
      if self.config.text_b_name:
        text_b_embs = "input_embs_" + self.config.text_b_name
        ret[text_b_embs] = lit_types.TokenEmbeddings(
            align="tokens_" + self.config.text_b_name
        )

    # Gradients, if requested.
    if self.config.compute_grads:
      ret["cls_grad"] = lit_types.Gradients(
          align=("score" if self.is_regression else "probas"),
          grad_for="cls_emb",
          grad_target_field_key=self.config.label_name,
      )
      if not self.is_regression:
        ret[self.config.label_name] = lit_types.CategoryLabel(
            required=False, vocab=self.config.labels
        )
      if self.config.output_embeddings:
        text_a_token_grads = "token_grad_" + self.config.text_a_name
        ret[text_a_token_grads] = lit_types.TokenGradients(
            align="tokens_" + self.config.text_a_name,
            grad_for="input_embs_" + self.config.text_a_name,
            grad_target_field_key=self.config.label_name,
        )
        if self.config.text_b_name:
          text_b_token_grads = "token_grad_" + self.config.text_b_name
          ret[text_b_token_grads] = lit_types.TokenGradients(
              align="tokens_" + self.config.text_b_name,
              grad_for="input_embs_" + self.config.text_b_name,
              grad_target_field_key=self.config.label_name,
          )

    if self.config.output_attention:
      # Attention heads, one field for each layer.
      for i in range(self.model.config.num_hidden_layers):
        ret[f"layer_{i+1}/attention"] = lit_types.AttentionHeads(
            align_in="tokens", align_out="tokens"
        )
    return ret
class SST2Model(GlueModel):
  """GLUE SST-2 sentiment classifier.

  Single-segment input ("sentence") with the string labels "0" and "1";
  label index 0 is marked as the null class.
  """

  def __init__(self, *args, **kw):
    # The task schema is fixed here; all other arguments (model path and
    # GlueModelConfig fields) are forwarded. Passing one of these keywords
    # again raises a TypeError (duplicate keyword argument).
    super().__init__(
        *args,
        labels=["0", "1"],
        null_label_idx=0,
        text_a_name="sentence",
        text_b_name=None,
        **kw,
    )
class MNLIModel(GlueModel):
  """GLUE MultiNLI classifier over (premise, hypothesis) pairs.

  Predicts one of "entailment", "neutral" or "contradiction".
  """

  def __init__(self, *args, **kw):
    # The task schema is fixed here; all other arguments (model path and
    # GlueModelConfig fields) are forwarded. Passing one of these keywords
    # again raises a TypeError (duplicate keyword argument).
    super().__init__(
        *args,
        labels=["entailment", "neutral", "contradiction"],
        text_a_name="premise",
        text_b_name="hypothesis",
        **kw,
    )
class STSBModel(GlueModel):
  """Regression model on STS-B (sentence-pair similarity score in [0, 5])."""

  def __init__(self, *args, **kw):
    # labels=None makes GlueModel a regression model.
    super().__init__(
        *args,
        text_a_name="sentence1",
        text_b_name="sentence2",
        labels=None,
        **kw,
    )

  def input_spec(self):
    """GlueModel's input spec, with the label range constrained to [0, 5].

    The label stays optional (required=False), as in GlueModel.input_spec():
    for a regression model, prediction never reads the label, so examples
    without one must remain valid inputs.
    """
    ret = super().input_spec()
    ret[self.config.label_name] = lit_types.Scalar(
        min_val=0, max_val=5, required=False
    )
    return ret