-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathfeature_test_case.py
498 lines (442 loc) · 16.9 KB
/
feature_test_case.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
# coding=utf-8
# Copyright 2023 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test case util to test `tfds.features.FeatureConnector`."""
import contextlib
import dataclasses
import functools
from typing import Any, Optional, Type
import dill
from etils import enp
import numpy as np
from tensorflow_datasets.core import dataset_utils
from tensorflow_datasets.core import features
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.utils import tree_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
from tensorflow_datasets.testing import test_case
from tensorflow_datasets.testing import test_utils
@dataclasses.dataclass
class FeatureExpectationItem:
  """Test item of a FeatureExpectation.

  Should be passed to the `assertFeature` method of
  `FeatureExpectationsTestCase`.

  Each `FeatureExpectationItem` tests the serialization/deserialization of one
  example (`feature.encode_example` -> `example_serializer.serialize_example`
  -> `example_parse.parse_example` -> `feature.decode_example`).

  If `raise_cls` or `raise_cls_np` is set, only the error checks are run and
  none of the `expected*` values are compared.

  Attributes:
    value: Input to `features.encode_example`.
    expected: Expected output after `features.decode_example`.
    expected_np: Expected output after `features.decode_example_np`.
    expected_serialized: Optional expected output of `feature.encode_example`
      (only compared when not `None`).
    decoders: Optional `tfds.decode.Decoder` objects (to overwrite the default
      `features.decode_example`). See
      https://www.tensorflow.org/datasets/decode.
    dtype: If `decoders` is provided, the output of `decode_example` is checked
      against this value (otherwise, output is checked against
      `features.dtype`). Must not be set without `decoders`.
    shape: If `decoders` is provided, the output of `decode_example` is checked
      against this value (otherwise, output is checked against
      `features.shape`). Must not be set without `decoders`.
    raise_cls: Expected error raised during the TensorFlow encode/decode round
      trip. When set, `expected` and `expected_serialized` should be `None`.
    raise_cls_np: Expected error raised during the NumPy encode/decode round
      trip. When set, `expected_np` should be `None`.
    raise_msg: Expected error message regex. Required whenever `raise_cls` or
      `raise_cls_np` is set.
    atol: If provided, compare float values with this absolute tolerance (use
      an exact comparison otherwise).
  """

  value: Any
  expected: Optional[Any] = None
  expected_np: Optional[np.ndarray] = None
  expected_serialized: Optional[Any] = None
  decoders: Optional[utils.TreeDict[Any]] = None
  dtype: Optional[tf.dtypes.DType] = None
  shape: Optional[utils.Shape] = None
  raise_cls: Optional[Type[Exception]] = None
  raise_cls_np: Optional[Type[Exception]] = None
  raise_msg: Optional[str] = None
  atol: Optional[float] = None

  def __post_init__(self):
    # `dtype`/`shape` override the expected output spec, which only makes
    # sense when custom `decoders` change what gets decoded.
    if not self.decoders and (self.dtype is not None or self.shape is not None):
      raise ValueError('dtype and shape should only be set with transform')
class SubTestCase(test_case.TestCase):
  """Adds a nestable `_subTest()` context manager to the TestCase.

  Note: To use this feature, make sure you call super() in setUpClass to
  initialize the sub-test stack.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Names of the currently open `_subTest` scopes, outermost first.
    cls._sub_test_stack = []

  @contextlib.contextmanager
  def _subTest(self, test_str):
    """Opens a `subTest` named after the path of all enclosing `_subTest`s.

    Args:
      test_str: Name of this level. It is joined with the names of the
        enclosing levels using `/` to form the `subTest` message.

    Yields:
      None.
    """
    self._sub_test_stack.append(test_str)
    sub_test_str = '/'.join(self._sub_test_stack)
    try:
      with self.subTest(sub_test_str):
        yield
    finally:
      # Pop even when an exception escapes `subTest`; otherwise the stale name
      # would be prepended to every later sub-test.
      self._sub_test_stack.pop()

  def assertAllEqualNested(self, d1, d2, *, atol: Optional[float] = None):
    """Same as assertAllEqual but compatible with nested dict.

    Dicts are compared key by key. `tf.data.Dataset` /
    `dataset_utils._IterableDataset` values are compared by length and then
    element by element.

    Args:
      d1: First element to compare.
      d2: Second element to compare.
      atol: If truthy, leaves are compared with `assertAllClose(atol=atol)`.
        Otherwise (`None` or `0`), they are compared exactly with
        `assertAllEqual`.
    """
    if isinstance(d1, dict):
      # assertAllEqual does not work well with dictionaries, so assert on each
      # individual element instead.
      zipped_examples = utils.zip_nested(d1, d2, dict_only=True)
      utils.map_nested(
          # Recursively call assertAllEqualNested in case a leaf is a dataset.
          lambda x: self.assertAllEqualNested(x[0], x[1], atol=atol),
          zipped_examples,
          dict_only=True,
      )
    elif isinstance(d1, (tf.data.Dataset, dataset_utils._IterableDataset)):  # pylint: disable=protected-access
      # Checks length and elements of the dataset. At the moment, more than one
      # level of nested datasets is not supported.
      self.assertEqual(len(d1), len(d2))
      for ex1, ex2 in zip(d1, d2):
        self.assertAllEqualNested(ex1, ex2, atol=atol)
    elif atol:
      self.assertAllClose(d1, d2, atol=atol)
    else:
      self.assertAllEqual(d1, d2)
class RaggedConstant:
  """Deferred `tf.ragged.constant(...)` call.

  Stores the positional and keyword arguments as given and only creates the
  `tf.RaggedTensor` when `build()` is called. This lets tests declare ragged
  expectations up front while constructing the tensor only after
  `@run_in_graph_and_eager_modes` has been applied, which avoids
  incompatibilities between graph and eager mode.
  """

  def __init__(self, *args, **kwargs):
    self._args = args
    self._kwargs = {**kwargs}

  def build(self):
    """Returns `tf.ragged.constant(*args, **kwargs)` for the stored args."""
    positional, keyword = self._args, self._kwargs
    return tf.ragged.constant(*positional, **keyword)
class FeatureExpectationsTestCase(SubTestCase):
  """Tests FeatureExpectations with full encode-decode.

  `assertFeature` / `assertFeatureEagerOnly` check a feature connector's static
  properties and round-trip every `FeatureExpectationItem` through a
  `FeaturesDict({'inner': feature})`. Unless `skip_feature_tests` is set, the
  same checks are repeated on copies of the feature restored from its config
  and from its proto.
  """

  @test_utils.run_in_graph_and_eager_modes()
  def assertFeature(
      self,
      feature,
      shape,
      dtype,
      tests,
      serialized_info=None,
      # TODO(b/227584124): remove this parameter after fixing this bug
      test_tensor_spec=True,
      skip_feature_tests=False,
      test_attributes=None,
  ):
    """Test the given feature against the predicates.

    Same as `assertFeatureEagerOnly`, but wrapped with
    `test_utils.run_in_graph_and_eager_modes()`.
    """
    self.assertFeatureEagerOnly(
        feature=feature,
        shape=shape,
        dtype=dtype,
        tests=tests,
        serialized_info=serialized_info,
        test_tensor_spec=test_tensor_spec,
        skip_feature_tests=skip_feature_tests,
        test_attributes=test_attributes,
    )

  def assertFeatureEagerOnly(
      self,
      feature,
      shape,
      dtype,
      tests,
      serialized_info=None,
      test_tensor_spec=True,
      skip_feature_tests=False,
      test_attributes=None,
  ):
    """Test the given feature against the predicates.

    Args:
      feature: The `FeatureConnector` under test.
      shape: Expected `feature.shape`.
      dtype: Expected `feature.dtype`.
      tests: Iterable of `FeatureExpectationItem` to round-trip.
      serialized_info: If truthy, expected `feature.get_serialized_info()`.
      test_tensor_spec: Whether to compare `feature.get_tensor_spec()` with the
        element spec of the decoded `tf.data.Dataset`.
      skip_feature_tests: If True, skip the config/proto restoration checks and
        the `test_attributes` checks.
      test_attributes: Optional mapping `attribute name -> expected value`,
        checked with `getattr` on the feature and its restored copies.
    """
    # Fill kwargs
    run_tests = functools.partial(
        self._run_tests,
        shape=shape,
        dtype=dtype,
        tests=tests,
        test_tensor_spec=test_tensor_spec,
    )
    assert_feature = functools.partial(
        self._assert_feature,
        shape=shape,
        dtype=dtype,
        tests=tests,
        serialized_info=serialized_info,
        test_tensor_spec=test_tensor_spec,
        skip_feature_tests=skip_feature_tests,
        test_attributes=test_attributes,
    )

    # Create the feature dict
    fdict = features.FeaturesDict({'inner': feature})
    # Check whether the following doesn't raise an exception
    fdict.catalog_documentation()

    with self._subTest('feature'):
      assert_feature(feature=feature)
      run_tests(serialize_fdict=fdict, deserialize_fdict=fdict, feature=feature)

    # Test the feature again to make sure it behaves correctly after restoring
    # TODO(tfds): Remove `skip_feature_tests` after text encoders are removed
    if not skip_feature_tests:
      # Restored from config. The checks stay inside the tmp dir context so any
      # files the restored feature refers to still exist.
      with test_utils.tmp_dir() as config_dir:
        feature.save_config(config_dir)
        new_feature = feature.from_config(config_dir)
        assert_feature(feature=new_feature)
        with self._subTest('feature_roundtrip'):
          run_tests(
              serialize_fdict=fdict, deserialize_fdict=fdict, feature=new_feature
          )

      # Restored from proto
      with test_utils.tmp_dir() as config_dir:
        feature_proto = feature.to_proto()
        feature.save_metadata(config_dir, feature_name=None)
        new_feature = feature_lib.FeatureConnector.from_proto(feature_proto)
        new_feature.load_metadata(config_dir, feature_name=None)
        assert_feature(feature=new_feature)

        new_fdict = features.FeaturesDict({'inner': new_feature})
        with self._subTest('feature_proto_roundtrip'):
          run_tests(
              serialize_fdict=fdict, deserialize_fdict=fdict, feature=new_feature
          )
        # Cross-check: encode with one dict and decode with the other.
        with self._subTest('serialize_fdict'):
          run_tests(
              serialize_fdict=new_fdict,
              deserialize_fdict=fdict,
              feature=new_feature,
          )
        with self._subTest('deserialize_fdict'):
          run_tests(
              serialize_fdict=fdict,
              deserialize_fdict=new_fdict,
              feature=new_feature,
          )

  def _assert_feature(
      self,
      feature,
      shape,
      dtype,
      tests,
      serialized_info=None,
      test_tensor_spec=True,
      skip_feature_tests=False,
      test_attributes=None,
  ):
    """Checks the static properties of `feature`.

    `tests` and `test_tensor_spec` are accepted for signature compatibility
    (the `assert_feature` partial in `assertFeatureEagerOnly` passes them) but
    are not used here.
    """
    del tests, test_tensor_spec  # Unused.
    with self._subTest('shape'):
      self.assertEqual(feature.shape, shape)
    with self._subTest('dtype'):
      self.assertEqual(feature.dtype, dtype)
      # `map_structure` only computes the per-leaf booleans; assert on them so
      # a leaf that is not a NumPy (resp. TF) dtype actually fails the test.
      np_dtype_ok = tf.nest.flatten(
          tree_utils.map_structure(enp.lazy.is_np_dtype, feature.np_dtype)
      )
      self.assertTrue(
          all(np_dtype_ok), f'Invalid np_dtype: {feature.np_dtype!r}'
      )
      tf_dtype_ok = tf.nest.flatten(
          tree_utils.map_structure(enp.lazy.is_tf_dtype, feature.tf_dtype)
      )
      self.assertTrue(
          all(tf_dtype_ok), f'Invalid tf_dtype: {feature.tf_dtype!r}'
      )

    # Check the serialized features
    if serialized_info:
      with self._subTest('serialized_info'):
        self.assertEqual(
            serialized_info,
            feature.get_serialized_info(),
        )

    if not skip_feature_tests and test_attributes:
      for key, value in test_attributes.items():
        self.assertEqual(getattr(feature, key), value)

  def _run_tests(
      self,
      serialize_fdict,
      deserialize_fdict,
      feature,
      shape,
      dtype,
      tests,
      test_tensor_spec=True,
  ):
    """Runs `assertFeatureTest` on each item of `tests` in its own sub-test."""
    for i, test in enumerate(tests):
      with self._subTest(str(i)):
        self.assertFeatureTest(
            serialize_fdict=serialize_fdict,
            deserialize_fdict=deserialize_fdict,
            test=test,
            feature=feature,
            shape=shape,
            dtype=dtype,
            test_tensor_spec=test_tensor_spec,
        )

  def assertFeatureTest(
      self,
      serialize_fdict,
      deserialize_fdict,
      test,
      feature,
      shape,
      dtype,
      test_tensor_spec: bool = True,
  ):
    """Test that encode=>decoding of a value works correctly.

    Args:
      serialize_fdict: Features dict whose `inner` entry encodes `test.value`.
      deserialize_fdict: Features dict whose `inner` entry decodes it back.
      test: The `FeatureExpectationItem` to check.
      feature: The feature under test.
      shape: Unused. The expected shape is `test.shape` or `feature.shape`.
      dtype: Unused. The expected dtype is `test.dtype` or `feature.dtype`.
      test_tensor_spec: Whether to compare `feature.get_tensor_spec()` with the
        decoded element spec.
    """
    del shape, dtype  # Unused, kept for signature compatibility.

    # Test that feature.encode_example can be pickled and unpickled for Beam.
    dill.loads(dill.dumps(feature.encode_example))

    input_value = {'inner': test.value}
    # Decoders must mirror the `{'inner': feature}` nesting of the fdicts, for
    # the error paths as well as the success paths.
    decoders = {'inner': test.decoders}

    if test.raise_cls is None and test.raise_cls_np is None:
      self._assert_encode_decode(
          serialize_fdict=serialize_fdict,
          deserialize_fdict=deserialize_fdict,
          test=test,
          feature=feature,
          input_value=input_value,
          decoders=decoders,
          test_tensor_spec=test_tensor_spec,
      )
      return

    if test.raise_cls is not None:
      self._assert_raises(
          subtest_name='raise',
          raise_cls=test.raise_cls,
          raise_msg=test.raise_msg,
          feature=feature,
          run_fn=lambda: features_encode_decode(
              serialize_fdict,
              deserialize_fdict,
              input_value,
              decoders=decoders,
          ),
      )
    if test.raise_cls_np is not None:
      self._assert_raises(
          subtest_name='raise_np',
          raise_cls=test.raise_cls_np,
          raise_msg=test.raise_msg,
          feature=feature,
          run_fn=lambda: features_encode_decode_np(
              serialize_fdict,
              deserialize_fdict,
              input_value,
              decoders=decoders,
          ),
      )

  def _assert_raises(self, subtest_name, raise_cls, raise_msg, feature, run_fn):
    """Asserts that `run_fn()` raises `raise_cls` with a matching message."""
    with self._subTest(subtest_name):
      if not raise_msg:
        raise ValueError(
            'test.raise_msg should be set with {} for test {}'.format(
                raise_cls, type(feature)
            )
        )
      with self.assertRaisesWithPredicateMatch(raise_cls, raise_msg):
        run_fn()

  def _assert_encode_decode(
      self,
      serialize_fdict,
      deserialize_fdict,
      test,
      feature,
      input_value,
      decoders,
      test_tensor_spec,
  ):
    """Checks the expected outputs of `test` for each encode/decode path."""
    # Test the serialization only
    if test.expected_serialized is not None:
      with self._subTest('out_serialize'):
        self.assertEqual(
            test.expected_serialized,
            feature.encode_example(test.value),
        )

    # Test serialization + decoding for the NumPy workflow
    if test.expected_np is not None:
      with self._subTest('out_np'):
        out_np = features_encode_decode_np(
            serialize_fdict,
            deserialize_fdict,
            input_value,
            decoders=decoders,
        )
        with self._subTest('out_np_value'):
          # Honor `test.atol` like the TF path below does.
          if test.atol:
            np.testing.assert_allclose(
                out_np['inner'], test.expected_np, atol=test.atol
            )
          else:
            np.testing.assert_array_equal(out_np['inner'], test.expected_np)

    # Test serialization + decoding through tf.data
    with self._subTest('out'):
      out_tensor, out_numpy, out_element_spec = features_encode_decode(
          serialize_fdict,
          deserialize_fdict,
          input_value,
          decoders=decoders,
      )
      out_tensor = out_tensor['inner']
      out_numpy = out_numpy['inner']
      out_element_spec = out_element_spec['inner']

      if test_tensor_spec:
        with self._subTest('tensor_spec'):
          # assertEqual (unlike a bare `assert`) is not stripped under -O.
          self.assertEqual(feature.get_tensor_spec(), out_element_spec)

      # Assert the returned type matches the expected one
      with self._subTest('dtype'):

        def _get_dtype(s):
          if isinstance(s, tf.data.Dataset):
            return tf.nest.map_structure(_get_dtype, s.element_spec)
          return s.dtype

        out_dtypes = tf.nest.map_structure(_get_dtype, out_tensor)
        self.assertEqual(out_dtypes, test.dtype or feature.dtype)

      with self._subTest('shape'):
        # For shape, because (None, 3) matches with (5, 3), we use
        # tf.TensorShape.assert_is_compatible_with on each of the elements
        expected_shape = feature.shape if test.shape is None else test.shape

        def _get_shape(s):
          if isinstance(s, tf.data.Dataset):
            return utils.map_nested(_get_shape, s.element_spec)
          return s.shape

        out_shapes = utils.map_nested(_get_shape, out_tensor)
        shapes_tuple = utils.zip_nested(out_shapes, expected_shape)
        utils.map_nested(
            lambda x: x[0].assert_is_compatible_with(x[1]), shapes_tuple
        )

      # Assert value
      with self._subTest('out_value'):
        # Build deferred `RaggedConstant`s into `tf.RaggedTensor`s now.
        expected = tf.nest.map_structure(
            lambda t: t.build() if isinstance(t, RaggedConstant) else t,
            test.expected,
        )
        self.assertAllEqualNested(out_numpy, expected, atol=test.atol)

      # Assert the HTML representation works
      if not test.decoders:
        with self._subTest('repr'):
          self._test_repr(feature, out_numpy)

  def _test_repr(
      self,
      feature: features.FeatureConnector,
      out_numpy: Any,
  ) -> None:
    """Test that the HTML repr works.

    For each flattened (example, feature, serialized info) triple, calls
    `repr_html`, `repr_html_batch` or `repr_html_ragged` depending on the
    spec's `sequence_rank`, and checks that a `str` is returned. Specs that
    are dicts (multi-data features) are skipped.
    """
    # pylint: disable=protected-access
    flat_example = feature._flatten(out_numpy)
    flat_features = feature._flatten(feature)
    flat_serialized_info = feature._flatten(feature.get_serialized_info())
    # pylint: enable=protected-access
    for ex, f, spec in zip(flat_example, flat_features, flat_serialized_info):
      # Features with multi-data not supported
      if isinstance(spec, dict):
        continue
      # TODO(tfds): Should use `as_dataframe._get_feature` instead, to
      # correctly compute for `sequence_rank` for subclasses like `Video`.
      elif spec.sequence_rank == 0:
        text = f.repr_html(ex)
      elif spec.sequence_rank == 1:
        text = f.repr_html_batch(ex)
      elif spec.sequence_rank > 1:
        text = f.repr_html_ragged(ex)
      else:
        # Otherwise `text` would be unbound, or stale from a previous
        # iteration, and the check below would be meaningless.
        self.fail(f'Unexpected sequence_rank: {spec.sequence_rank!r}')
      self.assertIsInstance(text, str)
def features_encode_decode(
    serialize_fdict, deserialize_fdict, example, decoders
):
  """Round-trips `example` through the TensorFlow encode/decode pipeline.

  The example is serialized with `serialize_fdict.serialize_example`, wrapped
  in a single-element `tf.data.Dataset`, and decoded by mapping
  `deserialize_fdict.deserialize_example` (with `decoders`) over it. This
  helper itself performs no file I/O.

  Args:
    serialize_fdict: Features dict used to serialize `example`.
    deserialize_fdict: Features dict used to decode the serialized example.
    example: Nested example to encode.
    decoders: Decoders forwarded to `deserialize_example`.

  Returns:
    A `(out_tensor, out_numpy, element_spec)` tuple: the decoded example as
    tensors, the same example converted by `dataset_utils.as_numpy`, and the
    `element_spec` of the decoding dataset.
  """
  serialized = serialize_fdict.serialize_example(example)
  dataset = tf.data.Dataset.from_tensors(serialized).map(
      functools.partial(
          deserialize_fdict.deserialize_example, decoders=decoders
      )
  )
  if not tf.executing_eagerly():
    # Graph mode: fetch the single element through a v1 one-shot iterator.
    decoded = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()
  else:
    decoded = next(iter(dataset))
  return decoded, dataset_utils.as_numpy(decoded), dataset.element_spec
def features_encode_decode_np(
    serialize_fdict, deserialize_fdict, example, decoders
):
  """Round-trips `example` through the NumPy encode/decode pipeline.

  Serializes with `serialize_fdict.serialize_example`, then decodes the result
  with `deserialize_fdict.deserialize_example_np`. This helper itself performs
  no file I/O.

  Args:
    serialize_fdict: Features dict used to serialize `example`.
    deserialize_fdict: Features dict used to decode the serialized example.
    example: Nested example to encode.
    decoders: Decoders forwarded to `deserialize_example_np`.

  Returns:
    The value returned by `deserialize_fdict.deserialize_example_np`.
  """
  return deserialize_fdict.deserialize_example_np(
      serialize_fdict.serialize_example(example), decoders=decoders
  )