|
19 | 19 |
|
20 | 20 | from __future__ import print_function |
21 | 21 |
|
| 22 | +import os |
| 23 | +import tempfile |
22 | 24 | from absl.testing import absltest |
23 | 25 | import apache_beam as beam |
24 | 26 | from apache_beam.testing import util |
25 | 27 | import numpy as np |
26 | 28 | import pyarrow as pa |
27 | 29 | from tensorflow_data_validation.api import stats_api |
28 | 30 | from tensorflow_data_validation.statistics import stats_options |
| 31 | +from tensorflow_data_validation.utils import io_util |
| 32 | +from tensorflow_data_validation.utils import stats_util |
29 | 33 | from tensorflow_data_validation.utils import test_util |
30 | 34 |
|
31 | 35 | from google.protobuf import text_format |
|
34 | 38 |
|
35 | 39 | class StatsAPITest(absltest.TestCase): |
36 | 40 |
|
| 41 | + def _get_temp_dir(self): |
| 42 | + return tempfile.mkdtemp() |
| 43 | + |
37 | 44 | def test_stats_pipeline(self): |
38 | 45 | record_batches = [ |
39 | 46 | pa.RecordBatch.from_arrays([ |
@@ -636,6 +643,42 @@ def test_invalid_stats_options(self): |
636 | 643 | p | beam.Create(record_batches) |
637 | 644 | | stats_api.GenerateStatistics(options={})) |
638 | 645 |
|
def test_write_stats_to_text(self):
  """Round-trips a statistics proto through WriteStatisticsToText."""
  expected_stats = text_format.Parse(
      """
      datasets {
        name: 'x'
        num_examples: 100
      }
      """, statistics_pb2.DatasetFeatureStatisticsList())
  stats_path = os.path.join(self._get_temp_dir(), 'stats')

  # Leaving the `with` block runs the pipeline before the file is read back.
  with beam.Pipeline() as pipeline:
    stats_pcoll = pipeline | beam.Create([expected_stats])
    _ = stats_pcoll | stats_api.WriteStatisticsToText(stats_path)

  # The file is read in binary mode and parsed as a serialized proto.
  raw_stats = io_util.read_file_to_string(stats_path, binary_mode=True)
  actual_stats = statistics_pb2.DatasetFeatureStatisticsList()
  actual_stats.ParseFromString(raw_stats)

  self.assertLen(actual_stats.datasets, 1)
  test_util.assert_dataset_feature_stats_proto_equal(
      self, actual_stats.datasets[0], expected_stats.datasets[0])
| 665 | + |
# NOTE(review): "tfrecrod" in the method name looks like a typo for
# "tfrecord"; the name is kept unchanged here so test filters keep working.
def test_write_stats_to_tfrecrod(self):
  """Round-trips a statistics proto through WriteStatisticsToTFRecord."""
  expected_stats = text_format.Parse(
      """
      datasets {
        name: 'x'
        num_examples: 100
      }
      """, statistics_pb2.DatasetFeatureStatisticsList())
  stats_path = os.path.join(self._get_temp_dir(), 'stats')

  # Leaving the `with` block runs the pipeline before the file is read back.
  with beam.Pipeline() as pipeline:
    stats_pcoll = pipeline | beam.Create([expected_stats])
    _ = stats_pcoll | stats_api.WriteStatisticsToTFRecord(stats_path)

  actual_stats = stats_util.load_statistics(stats_path)

  self.assertLen(actual_stats.datasets, 1)
  test_util.assert_dataset_feature_stats_proto_equal(
      self, actual_stats.datasets[0], expected_stats.datasets[0])
639 | 682 |
|
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
0 commit comments