-
Notifications
You must be signed in to change notification settings - Fork 45.6k
/
Copy pathinput_reader_builder.py
91 lines (77 loc) · 3.8 KB
/
input_reader_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input reader builder.
Creates data sources for DetectionModels from an InputReader config. See
input_reader.proto for options.
Note: If users wish to also use their own InputReaders with the Object
Detection configuration framework, they should define their own builder function
that wraps the build function.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import tf_slim as slim
from object_detection.data_decoders import tf_example_decoder
from object_detection.data_decoders import tf_sequence_example_decoder
from object_detection.protos import input_reader_pb2
# Module-level alias for `slim.parallel_reader`; `build` below calls its
# `parallel_read` function to read the configured TFRecord input paths.
parallel_reader = slim.parallel_reader
def build(input_reader_config):
  """Builds a tensor dictionary from an InputReader config.

  Only the `tf_record_input_reader` branch of the `input_reader` oneof is
  supported. The files listed in its `input_path` are read with
  `parallel_reader.parallel_read` using `tf.TFRecordReader`. The resulting
  string tensor is then decoded with a `TfExampleDecoder` or a
  `TfSequenceExampleDecoder`, selected by `input_reader_config.input_type`.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.

  Returns:
    The tensor dict returned by the selected decoder's `decode` method.

  Raises:
    ValueError: If `input_reader_config` is not an
      `input_reader_pb2.InputReader`, if its `input_reader` oneof is not
      `tf_record_input_reader`, or if `input_type` is neither TF_EXAMPLE nor
      TF_SEQUENCE_EXAMPLE.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  # Reject every reader kind other than TFRecord before doing any work.
  reader_kind = input_reader_config.WhichOneof('input_reader')
  if reader_kind != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  record_config = input_reader_config.tf_record_input_reader
  if not record_config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  # Slicing converts the `RepeatedScalarContainer` into a plain list.
  path_list = record_config.input_path[:]
  # A falsy num_epochs (e.g. the proto default of 0) is passed on as None.
  epoch_limit = input_reader_config.num_epochs or None
  # Only the second element of the returned pair (a string tensor) is used;
  # it is the input to the decoder below.
  _, serialized_records = parallel_reader.parallel_read(
      path_list,
      reader_class=tf.TFRecordReader,
      num_epochs=epoch_limit,
      num_readers=input_reader_config.num_readers,
      shuffle=input_reader_config.shuffle,
      dtypes=[tf.string, tf.string],
      capacity=input_reader_config.queue_capacity,
      min_after_dequeue=input_reader_config.min_after_dequeue)

  label_map_proto_file = (
      input_reader_config.label_map_path
      if input_reader_config.HasField('label_map_path') else None)

  input_type = input_reader_config.input_type
  if input_type == input_reader_pb2.InputType.Value('TF_EXAMPLE'):
    record_decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file,
        load_context_features=input_reader_config.load_context_features)
  elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
    record_decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
        label_map_proto_file=label_map_proto_file,
        load_context_features=input_reader_config.load_context_features,
        load_context_image_ids=input_reader_config.load_context_image_ids)
  else:
    raise ValueError('Unsupported input_type.')
  return record_decoder.decode(serialized_records)