Skip to content

Commit 38f3679

Browse files
paulgc17 and tfx-copybara
authored and committed
Add ptransforms to write statistics to text and tfrecord files.
PiperOrigin-RevId: 311405240
1 parent 8818451 commit 38f3679

File tree

8 files changed

+164
-35
lines changed

8 files changed

+164
-35
lines changed

RELEASE.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
* Add utility methods `tfdv.get_slice_stats` to get statistics for a slice and
1515
`tfdv.compare_slices` to compare statistics of two slices using Facets.
1616
* Make `tfdv.load_stats_text` and `tfdv.write_stats_text` public.
17+
* Add PTransforms `tfdv.WriteStatisticsToText` and
18+
`tfdv.WriteStatisticsToTFRecord` to write statistics proto to text and
19+
tfrecord files respectively.
20+
* Modify `tfdv.load_statistics` to handle reading statistics from TFRecord and
21+
text files.
1722
* Requires `pyarrow>=0.16,<1`.
1823

1924
## Known Issues

tensorflow_data_validation/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
# Import stats API.
2424
from tensorflow_data_validation.api.stats_api import GenerateStatistics
25+
from tensorflow_data_validation.api.stats_api import WriteStatisticsToText
26+
from tensorflow_data_validation.api.stats_api import WriteStatisticsToTFRecord
2527

2628
# Import validation API.
2729
from tensorflow_data_validation.api.validation_api import infer_schema
@@ -69,10 +71,10 @@
6971
from tensorflow_data_validation.utils.stats_gen_lib import generate_statistics_from_csv
7072
from tensorflow_data_validation.utils.stats_gen_lib import generate_statistics_from_dataframe
7173
from tensorflow_data_validation.utils.stats_gen_lib import generate_statistics_from_tfrecord
72-
from tensorflow_data_validation.utils.stats_gen_lib import load_statistics
7374

7475
# Import stats utilities.
7576
from tensorflow_data_validation.utils.stats_util import get_slice_stats
77+
from tensorflow_data_validation.utils.stats_util import load_statistics
7678
from tensorflow_data_validation.utils.stats_util import load_stats_text
7779
from tensorflow_data_validation.utils.stats_util import write_stats_text
7880

tensorflow_data_validation/api/stats_api.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from tensorflow_data_validation import constants
5151
from tensorflow_data_validation.statistics import stats_impl
5252
from tensorflow_data_validation.statistics import stats_options
53-
from typing import Generator
53+
from typing import Generator, Text
5454

5555
from tensorflow_metadata.proto.v0 import statistics_pb2
5656

@@ -130,3 +130,48 @@ def _sample_at_rate(example: pa.RecordBatch, sample_rate: float
130130
# or add an optional seed argument.
131131
if random.random() <= sample_rate:
132132
yield example
133+
134+
135+
@beam.typehints.with_input_types(statistics_pb2.DatasetFeatureStatisticsList)
136+
@beam.typehints.with_output_types(beam.pvalue.PDone)
137+
class WriteStatisticsToText(beam.PTransform):
138+
"""API for writing serialized data statistics to text file."""
139+
140+
def __init__(self, output_path: Text) -> None:
141+
"""Initializes the transform.
142+
143+
Args:
144+
output_path: Output path for writing data statistics.
145+
"""
146+
self._output_path = output_path
147+
148+
def expand(self, stats: beam.pvalue.PCollection) -> beam.pvalue.PDone:
149+
return (stats
150+
| 'WriteStats' >> beam.io.WriteToText(
151+
self._output_path,
152+
shard_name_template='',
153+
append_trailing_newlines=False,
154+
coder=beam.coders.ProtoCoder(
155+
statistics_pb2.DatasetFeatureStatisticsList)))
156+
157+
158+
@beam.typehints.with_input_types(statistics_pb2.DatasetFeatureStatisticsList)
159+
@beam.typehints.with_output_types(beam.pvalue.PDone)
160+
class WriteStatisticsToTFRecord(beam.PTransform):
161+
"""API for writing serialized data statistics to TFRecord file."""
162+
163+
def __init__(self, output_path: Text) -> None:
164+
"""Initializes the transform.
165+
166+
Args:
167+
output_path: Output path for writing data statistics.
168+
"""
169+
self._output_path = output_path
170+
171+
def expand(self, stats: beam.pvalue.PCollection) -> beam.pvalue.PDone:
172+
return (stats
173+
| 'WriteStats' >> beam.io.WriteToTFRecord(
174+
self._output_path,
175+
shard_name_template='',
176+
coder=beam.coders.ProtoCoder(
177+
statistics_pb2.DatasetFeatureStatisticsList)))

tensorflow_data_validation/api/stats_api_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,17 @@
1919

2020
from __future__ import print_function
2121

22+
import os
23+
import tempfile
2224
from absl.testing import absltest
2325
import apache_beam as beam
2426
from apache_beam.testing import util
2527
import numpy as np
2628
import pyarrow as pa
2729
from tensorflow_data_validation.api import stats_api
2830
from tensorflow_data_validation.statistics import stats_options
31+
from tensorflow_data_validation.utils import io_util
32+
from tensorflow_data_validation.utils import stats_util
2933
from tensorflow_data_validation.utils import test_util
3034

3135
from google.protobuf import text_format
@@ -34,6 +38,9 @@
3438

3539
class StatsAPITest(absltest.TestCase):
3640

41+
def _get_temp_dir(self):
42+
return tempfile.mkdtemp()
43+
3744
def test_stats_pipeline(self):
3845
record_batches = [
3946
pa.RecordBatch.from_arrays([
@@ -636,6 +643,42 @@ def test_invalid_stats_options(self):
636643
p | beam.Create(record_batches)
637644
| stats_api.GenerateStatistics(options={}))
638645

646+
def test_write_stats_to_text(self):
647+
stats = text_format.Parse(
648+
"""
649+
datasets {
650+
name: 'x'
651+
num_examples: 100
652+
}
653+
""", statistics_pb2.DatasetFeatureStatisticsList())
654+
output_path = os.path.join(self._get_temp_dir(), 'stats')
655+
with beam.Pipeline() as p:
656+
_ = (p | beam.Create([stats]) | stats_api.WriteStatisticsToText(
657+
output_path))
658+
stats_from_file = statistics_pb2.DatasetFeatureStatisticsList()
659+
serialized_stats = io_util.read_file_to_string(
660+
output_path, binary_mode=True)
661+
stats_from_file.ParseFromString(serialized_stats)
662+
self.assertLen(stats_from_file.datasets, 1)
663+
test_util.assert_dataset_feature_stats_proto_equal(
664+
self, stats_from_file.datasets[0], stats.datasets[0])
665+
666+
def test_write_stats_to_tfrecord(self):
667+
stats = text_format.Parse(
668+
"""
669+
datasets {
670+
name: 'x'
671+
num_examples: 100
672+
}
673+
""", statistics_pb2.DatasetFeatureStatisticsList())
674+
output_path = os.path.join(self._get_temp_dir(), 'stats')
675+
with beam.Pipeline() as p:
676+
_ = (p | beam.Create([stats]) | stats_api.WriteStatisticsToTFRecord(
677+
output_path))
678+
stats_from_file = stats_util.load_statistics(output_path)
679+
self.assertLen(stats_from_file.datasets, 1)
680+
test_util.assert_dataset_feature_stats_proto_equal(
681+
self, stats_from_file.datasets[0], stats.datasets[0])
639682

640683
if __name__ == '__main__':
641684
absltest.main()

tensorflow_data_validation/utils/stats_gen_lib.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from tensorflow_data_validation.statistics import stats_impl
4343
from tensorflow_data_validation.statistics import stats_options as options
4444
from tensorflow_data_validation.statistics.generators import stats_generator
45+
from tensorflow_data_validation.utils import stats_util
4546
from tfx_bsl.arrow import array_util
4647
from typing import Any, List, Optional, Text
4748

@@ -120,12 +121,9 @@ def generate_statistics_from_tfrecord(
120121
desired_batch_size=batch_size)
121122
| 'GenerateStatistics' >> stats_api.GenerateStatistics(stats_options)
122123
# TODO(b/112014711) Implement a custom sink to write the stats proto.
123-
| 'WriteStatsOutput' >> beam.io.WriteToTFRecord(
124-
output_path,
125-
shard_name_template='',
126-
coder=beam.coders.ProtoCoder(
127-
statistics_pb2.DatasetFeatureStatisticsList)))
128-
return load_statistics(output_path)
124+
| 'WriteStatsOutput' >> stats_api.WriteStatisticsToTFRecord(
125+
output_path))
126+
return stats_util.load_statistics(output_path)
129127

130128

131129
def generate_statistics_from_csv(
@@ -204,12 +202,9 @@ def generate_statistics_from_csv(
204202
desired_batch_size=batch_size)
205203
| 'GenerateStatistics' >> stats_api.GenerateStatistics(stats_options)
206204
# TODO(b/112014711) Implement a custom sink to write the stats proto.
207-
| 'WriteStatsOutput' >> beam.io.WriteToTFRecord(
208-
output_path,
209-
shard_name_template='',
210-
coder=beam.coders.ProtoCoder(
211-
statistics_pb2.DatasetFeatureStatisticsList)))
212-
return load_statistics(output_path)
205+
| 'WriteStatsOutput' >> stats_api.WriteStatisticsToTFRecord(
206+
output_path))
207+
return stats_util.load_statistics(output_path)
213208

214209

215210
def generate_statistics_from_dataframe(
@@ -348,19 +343,3 @@ def get_csv_header(data_location: Text,
348343
'Found empty file when reading the header line: %s' % filename)
349344

350345
return result
351-
352-
353-
def load_statistics(
354-
input_path: Text) -> statistics_pb2.DatasetFeatureStatisticsList:
355-
"""Loads data statistics proto from file.
356-
357-
Args:
358-
input_path: Data statistics file path.
359-
360-
Returns:
361-
A DatasetFeatureStatisticsList proto.
362-
"""
363-
serialized_stats = next(tf.compat.v1.io.tf_record_iterator(input_path))
364-
result = statistics_pb2.DatasetFeatureStatisticsList()
365-
result.ParseFromString(serialized_stats)
366-
return result

tensorflow_data_validation/utils/stats_util.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818

1919
from __future__ import print_function
2020

21+
import logging
2122
import numpy as np
2223
import pyarrow as pa
24+
import tensorflow as tf
2325
from tensorflow_data_validation import types
2426
from tensorflow_data_validation.arrow import arrow_util
2527
from tensorflow_data_validation.utils import io_util
@@ -212,6 +214,22 @@ def load_stats_text(
212214
return stats_proto
213215

214216

217+
def load_stats_tfrecord(
218+
input_path: Text) -> statistics_pb2.DatasetFeatureStatisticsList:
219+
"""Loads data statistics proto from TFRecord file.
220+
221+
Args:
222+
input_path: Data statistics file path.
223+
224+
Returns:
225+
A DatasetFeatureStatisticsList proto.
226+
"""
227+
serialized_stats = next(tf.compat.v1.io.tf_record_iterator(input_path))
228+
result = statistics_pb2.DatasetFeatureStatisticsList()
229+
result.ParseFromString(serialized_stats)
230+
return result
231+
232+
215233
def get_feature_stats(stats: statistics_pb2.DatasetFeatureStatistics,
216234
feature_path: types.FeaturePath
217235
) -> statistics_pb2.FeatureNameStatistics:
@@ -295,3 +313,27 @@ def get_slice_stats(statistics: statistics_pb2.DatasetFeatureStatisticsList,
295313
result.datasets.add().CopyFrom(slice_stats)
296314
return result
297315
raise ValueError('Invalid slice key.')
316+
317+
318+
def load_statistics(
319+
input_path: Text) -> statistics_pb2.DatasetFeatureStatisticsList:
320+
"""Loads data statistics proto from file.
321+
322+
Args:
323+
input_path: Data statistics file path. The file should be a one-record
324+
TFRecord file or a plain file containing the serialized statistics proto.
325+
326+
Returns:
327+
A DatasetFeatureStatisticsList proto.
328+
329+
Raises:
330+
IOError: If the input path does not exist.
331+
"""
332+
if not tf.io.gfile.exists(input_path):
333+
raise IOError('Invalid input path {}.'.format(input_path))
334+
try:
335+
return load_stats_tfrecord(input_path)
336+
except Exception: # pylint: disable=broad-except
337+
logging.info('File %s did not look like a TFRecord. Try reading as a plain '
338+
'file.', input_path)
339+
return load_stats_text(input_path)

tensorflow_data_validation/utils/stats_util_test.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from absl import flags
2323
from absl.testing import absltest
2424
import numpy as np
25+
import tensorflow as tf
2526
from tensorflow_data_validation import types
2627
from tensorflow_data_validation.utils import stats_util
2728

@@ -129,12 +130,23 @@ def test_get_utf8(self):
129130

130131
def test_write_load_stats_text(self):
131132
stats = text_format.Parse("""
132-
datasets {}
133+
datasets { name: 'abc' }
133134
""", statistics_pb2.DatasetFeatureStatisticsList())
134135
stats_path = os.path.join(FLAGS.test_tmpdir, 'stats.pbtxt')
135136
stats_util.write_stats_text(stats=stats, output_path=stats_path)
136-
loaded_stats = stats_util.load_stats_text(input_path=stats_path)
137-
self.assertEqual(stats, loaded_stats)
137+
self.assertEqual(stats, stats_util.load_stats_text(input_path=stats_path))
138+
self.assertEqual(stats, stats_util.load_statistics(input_path=stats_path))
139+
140+
def test_load_stats_tfrecord(self):
141+
stats = text_format.Parse("""
142+
datasets { name: 'abc' }
143+
""", statistics_pb2.DatasetFeatureStatisticsList())
144+
stats_path = os.path.join(FLAGS.test_tmpdir, 'stats.tfrecord')
145+
with tf.io.TFRecordWriter(stats_path) as writer:
146+
writer.write(stats.SerializeToString())
147+
self.assertEqual(stats,
148+
stats_util.load_stats_tfrecord(input_path=stats_path))
149+
self.assertEqual(stats, stats_util.load_statistics(input_path=stats_path))
138150

139151
def test_write_stats_text_invalid_stats_input(self):
140152
with self.assertRaisesRegexp(

tensorflow_data_validation/utils/validation_lib.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from tensorflow_data_validation.statistics import stats_impl
3232
from tensorflow_data_validation.statistics import stats_options as options
3333
from tensorflow_data_validation.utils import stats_gen_lib
34+
from tensorflow_data_validation.utils import stats_util
3435
from typing import List, Optional, Text
3536

3637
from tensorflow_metadata.proto.v0 import statistics_pb2
@@ -104,7 +105,7 @@ def validate_examples_in_tfrecord(
104105
coder=beam.coders.ProtoCoder(
105106
statistics_pb2.DatasetFeatureStatisticsList)))
106107

107-
return stats_gen_lib.load_statistics(output_path)
108+
return stats_util.load_statistics(output_path)
108109

109110

110111
def validate_examples_in_csv(
@@ -193,4 +194,4 @@ def validate_examples_in_csv(
193194
coder=beam.coders.ProtoCoder(
194195
statistics_pb2.DatasetFeatureStatisticsList)))
195196

196-
return stats_gen_lib.load_statistics(output_path)
197+
return stats_util.load_statistics(output_path)

0 commit comments

Comments (0)