ort_optimizers.py
from typing import Optional, Union

from onnx import ModelProto, load
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.capi._pybind_state import GraphOptimizationLevel

from ..cache import get_cache_file


def ort_optimized_model(
    onx: Union[str, ModelProto],
    level: str = "ORT_ENABLE_ALL",
    output: Optional[str] = None,
) -> Union[str, ModelProto]:
    """
    Returns the model optimized by onnxruntime before it runs the inference.

    :param onx: ModelProto or path to an ONNX file
    :param level: optimization level among `'ORT_ENABLE_BASIC'`,
        `'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'`
    :param output: output file if the proposed cache file is not wanted
    :return: optimized model, a path if *output* is given or *onx* is a path,
        a ModelProto otherwise
    """
    glevel = getattr(GraphOptimizationLevel, level, None)
    if glevel is None:
        raise ValueError(
            f"Unrecognized level {level!r} among {dir(GraphOptimizationLevel)}."
        )

    # Write the optimized model either to the requested file
    # or to a temporary cache file.
    if output is not None:
        cache = output
    else:
        cache = get_cache_file("ort_optimized_model.onnx", remove=True)

    so = SessionOptions()
    so.graph_optimization_level = glevel
    so.optimized_model_filepath = str(cache)
    # Creating the session triggers the graph optimization and dumps the
    # optimized model to ``so.optimized_model_filepath``.
    InferenceSession(
        onx if isinstance(onx, str) else onx.SerializeToString(),
        so,
        providers=["CPUExecutionProvider"],
    )
    if output is None and not cache.exists():
        raise RuntimeError(f"The optimized model {str(cache)!r} was not found.")
    if output is not None:
        return output
    if isinstance(onx, str):
        return str(cache)
    # The cache file is temporary: load the optimized model, then remove it.
    opt_onx = load(str(cache))
    cache.unlink()
    return opt_onx
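

# Minimal usage sketch, not part of the original file. The path
# "model.onnx" is hypothetical; any valid ONNX file works.
if __name__ == "__main__":
    # With a file path, the function returns the path to the optimized
    # model written by onnxruntime (the cache file is kept).
    optimized_path = ort_optimized_model("model.onnx", level="ORT_ENABLE_BASIC")
    print(optimized_path)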