-
Notifications
You must be signed in to change notification settings - Fork 359
/
Copy pathremote_model.py
109 lines (91 loc) · 3.26 KB
/
remote_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Client code for querying remote models hosted by a LIT server."""
from typing import Any, Optional
import urllib
import urllib.parse
import urllib.request

from absl import logging
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.lib import serialize
import requests
import six
# Module-level alias for the opener that query_lit_server uses to send
# requests. `urllib.urlopen` only existed in Python 2; in Python 3 the
# function lives in `urllib.request`.
urlopen = urllib.request.urlopen
JsonDict = types.JsonDict
def query_lit_server(
    url: str,
    endpoint: str,
    params: Optional[dict[str, Any]] = None,
    inputs: Optional[Any] = None,
    config: Optional[Any] = None,
) -> Any:
  """Query a LIT server from Python.

  Sends a JSON POST request to `endpoint` on the server at `url` and returns
  the decoded JSON reply.

  Args:
    url: server address. If it does not already start with 'http://' or
      'https://', 'http://' is prepended.
    endpoint: endpoint name (e.g. 'get_info'), resolved against `url` with
      urllib.parse.urljoin; note that urljoin replaces the last path segment
      of `url` if it does not end in '/'.
    params: optional URL query parameters; encoded into the URL by
      `requests` when the request is prepared.
    inputs: value sent under the 'inputs' key of the JSON body.
    config: value sent under the 'config' key of the JSON body.

  Returns:
    The response body, decoded with serialize.from_json.

  Raises:
    IOError: if the server replies with a status code other than 200.
    Errors raised by `urlopen` itself (e.g. urllib.error.HTTPError for
    4xx/5xx statuses) propagate unchanged.
  """
  # Pack data for LIT request
  data = {'inputs': inputs, 'config': config}
  # TODO(lit-dev): for open source, require HTTPS.
  # Only add a default scheme when none is given; prefixing unconditionally
  # would turn 'https://host' into 'http://https://host'.
  if not url.startswith(('http://', 'https://')):
    url = 'http://' + url
  full_url = urllib.parse.urljoin(url, endpoint)
  # Use requests only to build the final URL (query params) and the body;
  # the request itself is sent with urllib below.
  prepared = requests.Request(
      'POST',
      full_url,
      params=params,
      data=serialize.to_json(data),
      headers={'Content-Type': 'application/json'},
  ).prepare()
  # Convert to urllib request
  request = urllib.request.Request(
      url=prepared.url,
      data=six.ensure_binary(prepared.body) if prepared.body else None,
      headers=prepared.headers,
      method=prepared.method,
  )
  # Use the response as a context manager so the connection is always
  # released, including when the status check below raises.
  with urlopen(request) as response:
    if response.code != 200:
      raise IOError(
          f'Failed to query {prepared.url}; response code {response.code}')
    # TODO(iftenney): handle bad server response, e.g. if corplogin is required
    # and the server sends a login page instead of a JSON response.
    response_bytes = response.read()
  return serialize.from_json(six.ensure_text(response_bytes))
class RemoteModel(lit_model.BatchedModel):
  """Wraps a single model hosted on another LIT server.

  The model's input and output specs are fetched once, from the server's
  'get_info' endpoint, when the wrapper is constructed. Each minibatch passed
  to predict_minibatch is sent to the server as one 'get_preds' request.
  """

  def __init__(self, url: str, name: str, max_minibatch_size: int = 256):
    """Connects to the remote server and caches the model's specs.

    Args:
      url: url of LIT server
      name: name of model on the remote server
      max_minibatch_size: batch size used for remote requests
    """
    self._url = url
    self._name = name
    info = query_lit_server(self._url, 'get_info')
    # Indexed directly: a name the server does not report fails here.
    remote_spec = info['models'][self._name]['spec']
    self._input_spec = remote_spec['input']
    self._output_spec = remote_spec['output']
    self._max_minibatch_size = max_minibatch_size

  def input_spec(self):
    """Returns the input spec reported by the remote server."""
    return self._input_spec

  def output_spec(self):
    """Returns the output spec reported by the remote server."""
    return self._output_spec

  def max_minibatch_size(self):
    """Returns the batch size configured at construction time."""
    return self._max_minibatch_size

  def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
    """Sends one minibatch to the remote server and returns its predictions.

    Args:
      inputs: examples to predict on.

    Returns:
      The value decoded from the server's 'get_preds' response.
    """
    # The server is sent IndexedInput-style records; ids are dummies (None).
    wrapped = [dict(id=None, data=example) for example in inputs]
    batch_size = len(wrapped)
    logging.info('Querying remote model: /get_preds on %d examples',
                 batch_size)
    request_params = {'model': self._name, 'response_simple_json': False}
    predictions = query_lit_server(
        self._url, 'get_preds', params=request_params, inputs=wrapped)
    logging.info('Received %d predictions from remote model.',
                 len(predictions))
    return predictions
def models_from_server(url: str, **model_kw) -> dict[str, RemoteModel]:
  """Wrap all the models on a given LIT server.

  Args:
    url: address of the LIT server; passed unchanged to query_lit_server and
      to each RemoteModel.
    **model_kw: extra keyword arguments forwarded to every RemoteModel
      constructor (e.g. max_minibatch_size).

  Returns:
    A dict mapping each model name listed under 'models' in the server's
    'get_info' response to a RemoteModel wrapping it.
  """
  server_info = query_lit_server(url, 'get_info')
  # NOTE: RemoteModel.__init__ issues its own 'get_info' request, so this
  # makes one request here plus one per model.
  return {
      name: RemoteModel(url, name, **model_kw)
      for name in server_info['models']
  }