-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_helpers.py
49 lines (47 loc) · 1.54 KB
/
_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
from typing import Any
from onnx import helper, TensorProto
# Fallback (candidate, TensorProto element type) pairs, tried in order when
# onnx's own helper cannot map a dtype. Candidates are matched with ``==``
# (not a dict lookup) because numpy's dtype equality treats e.g.
# ``np.dtype("float32")`` and ``np.float32`` as equal, so both forms match.
_NP_FALLBACK_TYPES = (
    (np.float32, TensorProto.FLOAT),
    (np.float64, TensorProto.DOUBLE),
    (np.int64, TensorProto.INT64),
    (np.int32, TensorProto.INT32),
    (np.int16, TensorProto.INT16),
    (np.int8, TensorProto.INT8),
    (np.uint64, TensorProto.UINT64),
    (np.uint32, TensorProto.UINT32),
    (np.uint16, TensorProto.UINT16),
    (np.uint8, TensorProto.UINT8),
    (np.float16, TensorProto.FLOAT16),
    (bool, TensorProto.BOOL),
    (np.bool_, TensorProto.BOOL),
    (str, TensorProto.STRING),
    (np.str_, TensorProto.STRING),
    (np.complex64, TensorProto.COMPLEX64),
    (np.complex128, TensorProto.COMPLEX128),
)


def np_dtype_to_tensor_dtype(dtype: Any) -> int:
    """
    Improves :func:`onnx.helper.np_dtype_to_tensor_dtype`.

    The onnx helper is tried first. If it rejects ``dtype`` with a
    ``KeyError`` or ``ValueError``, a fallback table is consulted. The table
    also accepts numpy scalar types (``np.float32``, ...) and the Python
    builtins ``bool``, ``str``, ``int`` (-> INT64) and ``float`` (-> DOUBLE).
    Strings naming a numpy dtype (e.g. ``"float32"``) are normalized with
    :class:`numpy.dtype` first.

    :param dtype: a numpy dtype, numpy scalar type, Python builtin type,
        or numpy dtype name
    :return: the matching ``TensorProto`` element type
    :raises KeyError: if no mapping is found; the onnx helper's error is
        chained as ``__cause__``
    """
    original = dtype
    if isinstance(dtype, str):
        try:
            dtype = np.dtype(dtype)
        except TypeError:
            # Not a numpy dtype name; let the lookups below report it.
            pass
    try:
        return helper.np_dtype_to_tensor_dtype(dtype)
    except (KeyError, ValueError) as exc:
        # Python builtins are matched by identity so that only the types
        # themselves take these branches.
        if dtype is int:
            return TensorProto.INT64
        if dtype is float:
            return TensorProto.DOUBLE
        for candidate, tensor_type in _NP_FALLBACK_TYPES:
            if dtype == candidate:
                return tensor_type
        raise KeyError(f"Unable to guess type for dtype={original}.") from exc