-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathmodule_v2.py
128 lines (105 loc) · 5.13 KB
/
module_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Hub Module API for Tensorflow 2.0."""
import os
import tensorflow as tf
from tensorflow_hub import registry
# File name of the module proto inside a module directory. Its presence is
# what load() uses to decide that a module is in the legacy TF1 Hub format.
_MODULE_PROTO_FILENAME_PB = "tfhub_module.pb"
def _get_module_proto_path(module_dir):
  """Builds the path of the TF1 Hub module proto file within `module_dir`.

  Both the directory and the proto file name are converted to bytes before
  joining, so the returned path is a bytes object.

  Args:
    module_dir: Directory of a resolved module (str or bytes).

  Returns:
    The bytes path `<module_dir>/tfhub_module.pb`.
  """
  dir_bytes = tf.compat.as_bytes(module_dir)
  proto_name_bytes = tf.compat.as_bytes(_MODULE_PROTO_FILENAME_PB)
  proto_path = os.path.join(dir_bytes, proto_name_bytes)
  return proto_path
def resolve(handle):
  """Turns a module handle into a module path.

  Works for plain TF2 SavedModels as well as for the legacy TF1 Hub format.
  When needed, the module is downloaded and cached in the location given by
  TFHUB_CACHE_DIR.

  Three kinds of module handles are currently supported:
    1) Smart URL resolvers such as tfhub.dev, e.g.
       https://tfhub.dev/google/nnlm-en-dim128/1.
    2) A directory containing the module files on a file system supported by
       TensorFlow. This may be a local directory (e.g. /usr/local/mymodule)
       or a Google Cloud Storage bucket (gs://mymodule).
    3) A URL pointing to a TGZ archive of a module, e.g.
       https://example.com/mymodule.tar.gz.

  Args:
    handle: (string) the Module handle to resolve.

  Returns:
    A string representing the Module path.
  """
  # All handle-type dispatch, downloading and caching is delegated to the
  # registered resolver.
  resolved_path = registry.resolver(handle)
  return resolved_path
def load(handle, tags=None, options=None):
  """Resolves a handle and loads the resulting module.

  This is the preferred API to load a Hub module in low-level TensorFlow 2.
  Users of higher-level frameworks like Keras should use the framework's
  corresponding wrapper, like hub.KerasLayer.

  This function is roughly equivalent to the TF2 function
  `tf.saved_model.load()` on the result of `hub.resolve(handle)`. Calling this
  function requires TF 1.14 or newer. It can be called both in eager and graph
  mode.

  Note: Using in a tf.compat.v1.Session with variables placed on parameter
  servers requires setting `experimental.share_cluster_devices_in_session`
  within the `tf.compat.v1.ConfigProto`. (It becomes non-experimental in TF2.2.)

  This function can handle the deprecated TF1 Hub format to the extent
  that `tf.saved_model.load()` in TF2 does. In particular, the returned object
  has attributes
    * `.variables`: a list of variables from the loaded object;
    * `.signatures`: a dict of TF2 ConcreteFunctions, keyed by signature names,
      that take tensor kwargs and return a tensor dict.
  However, the information imported by hub.Module into the collections of a
  tf.Graph is lost (e.g., regularization losses and update ops).

  Args:
    handle: (string) the Module handle to resolve; see hub.resolve().
    tags: A set of strings specifying the graph variant to use, if loading from
      a v1 module. If None and the module is in the TF1 Hub format, the empty
      tag set is used.
    options: Optional, `tf.saved_model.LoadOptions` object that specifies
      options for loading. This argument can only be used from TensorFlow 2.3
      onwards.

  Returns:
    A trackable object (see tf.saved_model.load() documentation for details).
    It additionally carries an `_is_hub_module_v1` attribute that records
    whether the module was in the TF1 Hub format.

  Raises:
    ValueError: If `handle` is not a string, or if the resolved module path
      contains neither a binary nor a text SavedModel proto file.
    NotImplementedError: If `options` is given but the running TF version does
      not provide `tf.saved_model.LoadOptions` (TF < 2.3).
  """
  if not isinstance(handle, str):
    # Wrap in a 1-tuple so that a tuple-valued handle is formatted as a whole;
    # a bare `% handle` would spread its items over the format string and
    # raise TypeError instead of the documented ValueError.
    raise ValueError("Expected a string, got %s" % (handle,))
  module_path = resolve(handle)

  # A legacy TF1 Hub module is recognized by its module proto file.
  is_hub_module_v1 = tf.io.gfile.exists(_get_module_proto_path(module_path))
  if tags is None and is_hub_module_v1:
    # TF1 Hub modules are loaded with the empty tag set unless told otherwise.
    tags = []

  # Require a SavedModel proto (binary or text) before handing off to TF, so
  # that unknown directory layouts fail with a clear message.
  module_path_bytes = tf.compat.as_bytes(module_path)
  saved_model_path = os.path.join(
      module_path_bytes,
      tf.compat.as_bytes(tf.saved_model.SAVED_MODEL_FILENAME_PB))
  saved_model_pbtxt_path = os.path.join(
      module_path_bytes,
      tf.compat.as_bytes(tf.saved_model.SAVED_MODEL_FILENAME_PBTXT))
  if (not tf.io.gfile.exists(saved_model_path) and
      not tf.io.gfile.exists(saved_model_pbtxt_path)):
    raise ValueError("Trying to load a model of incompatible/unknown type. "
                     "'%s' contains neither '%s' nor '%s'." %
                     (module_path, tf.saved_model.SAVED_MODEL_FILENAME_PB,
                      tf.saved_model.SAVED_MODEL_FILENAME_PBTXT))

  if options:
    # LoadOptions only exists from TF 2.3 onwards.
    if not hasattr(getattr(tf, "saved_model", None), "LoadOptions"):
      raise NotImplementedError("options are not supported for TF < 2.3.x,"
                                " Current version: %s" % tf.__version__)
    # tf.compat.v1.saved_model.load_v2() is the TF2 tf.saved_model.load(),
    # also reachable on TF versions before 2.0.
    obj = tf.compat.v1.saved_model.load_v2(
        module_path, tags=tags, options=options)
  else:
    obj = tf.compat.v1.saved_model.load_v2(module_path, tags=tags)
  # Record the detected format on the returned object; presumably consumed
  # elsewhere in tensorflow_hub (TODO confirm callers).
  obj._is_hub_module_v1 = is_hub_module_v1  # pylint: disable=protected-access
  return obj